1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9# code and classes for controlling SMT solvers, including 'fast' solvers,
10# which support SMTLIB2 push/pop and are controlled by pipe, and heavyweight
11# 'slow' solvers which are run once per problem on static input files.
12import signal
# Help text printed by find_solverlist_file when no '.solverlist' file can
# be found in the current directory or any of its parents.
solverlist_missing = """
This tool requires the use of an SMT solver.

This tool searches for the file '.solverlist' in the current directory and
in every parent directory up to the filesystem root.
"""
# Description of the '.solverlist' file format. Printed by parse_solver when
# a line cannot be parsed and by find_solverlist_file when the file is
# missing.
solverlist_format = """
The .solverlist format is one solver per line, e.g.

# SONOLAR is the strongest offline solver in our experiments.
SONOLAR: offline: /home/tsewell/bin/sonolar --input-format=smtlib2
# CVC4 is useful in online and offline mode.
CVC4: online: /home/tsewell/bin/cvc4 --incremental --lang smt --tlimit=5000
CVC4: offline: /home/tsewell/bin/cvc4 --lang smt
# Z3 is a useful online solver. Use of Z3 in offline mode is not recommended,
# because it produces incompatible models.
Z3 4.3: online: /home/tsewell/dev/z3-dist/build/z3 -t:2 -smt2 -in
# Z3 4.3: offline: /home/tsewell/dev/z3-dist/build/z3 -smt2 -in

N.B. only ONE online solver is needed, so Z3 is redundant in the above.

Each non-comment line is ':' separated, with this pattern:
name : online/offline/fast/slow : command

The name is used to identify the solver. The second token specifies
the solver mode. Solvers in "fast" or "online" mode must support all
interactive SMTLIB2 features including push/pop. With "slow" or "offline" mode
the solver will be executed once per query, and push/pop will not be used.

The remainder of each line is a shell command that executes the solver in
SMTLIB2 mode. For online solvers it is typically worth setting a resource
limit, after which the offline solver will be run.

The first online solver will be used. The offline solvers will be used in
parallel, by default. The set to be used in parallel can be controlled with
a strategy line e.g.:
strategy: SONOLAR all, SONOLAR hyp, CVC4 hyp

This specifies that SONOLAR and CVC4 should both be run on each hypothesis. In
addition SONOLAR will be applied to try to solve all related hypotheses at
once, which may be faster than solving them one at a time.
"""

# One-element list used as a mutable cell: find_solverlist_file stores the
# path of the '.solverlist' file it located in slot 0, and parse_solver
# reports that path in its error message.
solverlist_file = ['.solverlist']
class SolverImpl:
	"""Settings for one solver entry read from the '.solverlist' file.

	Attributes:
	  origname -- the name exactly as given to the constructor.
	  name     -- origname with ' (online)' or ' (offline)' appended,
	              depending on 'fast'.
	  fast     -- true for an online solver, false for an offline one.
	  args     -- the solver command as an argument list.
	  timeout  -- timeout value stored for this solver.
	  mem_mode -- None until changed (parse_config_change sets it to
	              '8' or '32').
	"""
	def __init__ (self, name, fast, args, timeout):
		self.origname = name
		self.fast = fast
		self.args = args
		self.timeout = timeout
		self.mem_mode = None
		suffix = ' (online)' if fast else ' (offline)'
		self.name = name + suffix

	def __repr__ (self):
		shown = (self.name, self.fast, self.args, self.timeout)
		return 'SolverImpl (%r, %r, %r, %r)' % shown
72
def parse_solver (bits):
	"""Build a SolverImpl from the fields of one '.solverlist' solver line.

	bits is the list [name, mode, command] produced by splitting the line
	on ':'. mode must be one of fast/slow/online/offline (any case) and
	command must contain at least one word. If the line is malformed the
	format help is printed and the process exits with status 1.

	The first word of the command must name an existing file; this is
	checked with an assertion. Online solvers get a timeout of 30,
	offline solvers 6000.
	"""
	import os
	import sys
	mode_set = ['fast', 'slow', 'online', 'offline']
	args = []
	if len (bits) >= 3:
		args = bits[2].split ()
	# an empty command is reported as a parse problem rather than being
	# allowed to fail later with an IndexError on args[0].
	if (len (bits) < 3 or bits[1].lower () not in mode_set
			or not args):
		print ('solver.py: solver list could not be parsed')
		print ('  in %s' % solverlist_file[0])
		print ('  reading %r' % bits)
		print (solverlist_format)
		sys.exit (1)
	name = bits[0]
	fast = (bits[1].lower () in ['fast', 'online'])
	assert os.path.exists (args[0]), (args[0], bits)
	if fast:
		timeout = 30
	else:
		timeout = 6000
	return SolverImpl (name, fast, args, timeout)
92
def find_solverlist_file ():
	"""Locate the '.solverlist' file and return its path.

	The search starts in the current working directory and moves up one
	parent at a time. The path found is also stored in solverlist_file[0].
	If the filesystem root is reached without finding the file, the help
	texts are printed and the process exits with status 1.
	"""
	import os
	import sys
	directory = os.path.abspath (os.getcwd ())
	while True:
		candidate = os.path.join (directory, '.solverlist')
		if os.path.exists (candidate):
			break
		parent = os.path.split (directory)[0]
		if parent == directory:
			# the root is its own parent: nowhere left to look.
			print ("solver.py: '.solverlist' missing")
			print (solverlist_missing)
			print (solverlist_format)
			sys.exit (1)
		directory = parent
	solverlist_file[0] = candidate
	return candidate
108
def get_solver_set ():
	"""Read the '.solverlist' file.

	Returns (solvers, strategy): solvers is the list of SolverImpl
	objects in file order, strategy is the parsed 'strategy:' line or
	None if the file has none (a later strategy line replaces an
	earlier one).

	Blank lines and lines starting with '#' are ignored. A 'config:'
	line is applied to the most recently listed solver, so at least one
	solver must precede it.
	"""
	solvers = []
	strategy = None
	# the with-block guarantees the file is closed, including when a
	# parse helper exits or raises part way through.
	with open (find_solverlist_file ()) as solverlist:
		for line in solverlist:
			line = line.strip ()
			if not line or line.startswith ('#'):
				continue
			bits = [bit.strip () for bit in line.split (':', 2)]
			if bits[0] == 'strategy':
				[_, strat] = bits
				strategy = parse_strategy (strat)
			elif bits[0] == 'config':
				[_, config] = bits
				assert solvers
				parse_config_change (config, solvers[-1])
			else:
				solvers.append (parse_solver (bits))
	return (solvers, strategy)
127
def parse_strategy (strat):
	"""Parse the body of a '.solverlist' strategy line.

	strat is a comma separated list of "<solver name> <all|hyp>" items.
	Returns a list of (solver name, 'all' or 'hyp') tuples in the order
	given. A malformed item is reported and the process exits with
	status 1.
	"""
	# imported here because this function runs while the module is still
	# being loaded, before the module-level 'import sys' has executed.
	import sys
	strategy = []
	for solv in strat.split (','):
		bits = solv.split ()
		if len (bits) != 2 or bits[1] not in ['all', 'hyp']:
			print ("solver.py: strategy element %r" % bits)
			print ("found in .solverlist strategy line")
			print ("should be [solvername, 'all' or 'hyp']")
			sys.exit (1)
		strategy.append (tuple (bits))
	return strategy
140
def parse_config_change (config, solver):
	"""Apply the settings of a '.solverlist' config line to solver.

	config is a comma separated list of 'key = value' assignments. Keys
	and values are compared case-insensitively. The only key understood
	is 'mem_mode', whose value must be '8' or '32' and is stored in
	solver.mem_mode. Anything else fails an assertion.
	"""
	for assign in config.split (','):
		bits = assign.split ('=')
		assert len (bits) == 2, bits
		key = bits[0].strip ().lower ()
		value = bits[1].strip ().lower ()
		if key != 'mem_mode':
			assert not 'config understood', assign
		else:
			assert value in ['8', '32']
			solver.mem_mode = value
153
def load_solver_set ():
	"""Load the solver configuration and resolve the parallel strategy.

	Returns (fast solver, slow solver, strategy, model strategy) where
	the first two are the first online and first offline solvers listed,
	strategy is a list of (offline SolverImpl, 'all' or 'hyp') pairs and
	the last element is the collection of offline solvers, indexed by
	their original names.

	With no strategy line, every offline solver is used in both 'all'
	and 'hyp' mode. A strategy naming an unknown offline solver is
	reported and the process exits with status 1. At least one online
	and one offline solver must be configured (asserted).
	"""
	import sys
	(solvers, strategy) = get_solver_set ()
	fast_solvers = []
	slow_solvers = []
	for sv in solvers:
		if sv.fast:
			fast_solvers.append (sv)
		else:
			slow_solvers.append (sv)
	slow_dict = {}
	for sv in slow_solvers:
		slow_dict[sv.origname] = sv
	if strategy is None:
		strategy = []
		for nm in slow_dict:
			for strat in ['all', 'hyp']:
				strategy.append ((nm, strat))
	for (nm, _) in strategy:
		if nm in slow_dict:
			continue
		print ("solver.py: strategy option %r" % nm)
		print ("found in .solverlist strategy line")
		print ("not an offline solver (required for parallel use)")
		print ("(known offline solvers %s)" % slow_dict.keys ())
		sys.exit (1)
	resolved = [(slow_dict[nm], strat) for (nm, strat) in strategy]
	assert fast_solvers, solvers
	assert slow_solvers, solvers
	return (fast_solvers[0], slow_solvers[0], resolved, slow_dict.values ())
174
# The solver configuration is loaded once, when this module is imported.
# Importing the module therefore needs a usable '.solverlist' file:
# load_solver_set exits the process or fails an assertion otherwise.
(fast_solver, slow_solver, strategy, model_strategy) = load_solver_set ()
176
177from syntax import (Expr, fresh_name, builtinTs, true_term, false_term,
178  foldr1, mk_or, boolT, word32T, word8T, mk_implies, Type, get_global_wrapper)
179from target_objects import structs, rodata, sections, trace, printout
180from logic import mk_align_valid_ineq, pvalid_assertion1, pvalid_assertion2
181
182import syntax
183import subprocess
184import sys
185import resource
186import re
187import random
188import time
189import tempfile
190import os
191
# Module-level state. The one-element lists act as mutable cells that
# functions can update without a 'global' statement.

# last_solver[0] is set to each Solver as it is constructed.
last_solver = [None]
# NOTE(review): the next two are not referenced in the part of the file
# reviewed here; presumably debugging records of recent solver results.
last_10_models = []
last_satisfiable_hyps = [None]
# last_hyps[0] is read by write_last_solv_script as a sequence of
# (hypothesis, _) pairs.
last_hyps = [None]
# NOTE(review): not referenced in the part of the file reviewed here.
last_check_model_state = [None]
inconsistent_hyps = []

# Solver objects with a started online solver, oldest first. startup_solver
# closes the oldest entries once the list exceeds max_active_solvers[0].
active_solvers = []
max_active_solvers = [5]

# random_name and the running count give save_solv_example its file names.
random_name = random.randrange (1, 10 ** 9)
count = [0]

# NOTE(review): not referenced in the part of the file reviewed here.
save_solv_example_time = [-1]
206
def save_solv_example (solv, last_msgs, comments = []):
	"""Save a solver script as a new file under 'smt_examples/'.

	The file is named 'ex_<random_name>_<count>', with the module-level
	count incremented on every call. Each string in comments is written
	first as a ';' comment line, then solv.write_solv_script is asked to
	write the script for last_msgs.

	The default comments list is only iterated, never modified.
	"""
	count[0] += 1
	name = 'ex_%d_%d' % (random_name, count[0])
	# the with-block closes the file even if writing the script raises.
	with open ('smt_examples/' + name, 'w') as f:
		for msg in comments:
			f.write ('; ' + msg + '\n')
		solv.write_solv_script (f, last_msgs)
215
def write_last_solv_script (solv, fname):
	"""Write a script for the hypotheses recorded in last_hyps[0] to fname.

	Each recorded (hypothesis, _) pair becomes an '(assert ...)' command
	and a final '(check-sat)' is appended; solv.write_solv_script does
	the actual writing.
	"""
	hyps = last_hyps[0]
	cmds = ['(assert %s)' % hyp for (hyp, _) in hyps] + ['(check-sat)']
	# the with-block closes the file even if writing the script raises.
	with open (fname, 'w') as f:
		solv.write_solv_script (f, cmds)
222
def run_time (elapsed, proc):
	"""Describe the running time of process proc as a bracketed string.

	elapsed is the wall-clock time if already known, or None. psutil is
	used to look up the user and system CPU times of proc.pid and, when
	elapsed is None, to derive it from the process creation time.

	Returns e.g. '(1.50s elapsed, 0.20s user, 0.01s sys)', listing only
	the values that could be obtained, '(unknown time)' if there are
	none, or a fixed message if psutil cannot be imported.
	"""
	user_time = None
	sys_time = None
	try:
		import psutil
		info = psutil.Process (proc.pid)
		cpu = info.cpu_times ()
		user_time = cpu.user
		sys_time = cpu.system
		if elapsed is None:
			elapsed = time.time () - info.create_time ()
	except ImportError:
		return '(cannot import psutil, cannot time solver)'
	except Exception:
		# best effort: report whatever was collected before the failure.
		pass
	labelled = [(elapsed, 'elapsed'), (user_time, 'user'),
		(sys_time, 'sys')]
	times = ['%.2fs %s' % (t, msg) for (t, msg) in labelled
		if t is not None]
	if times:
		return '(%s)' % ', '.join (times)
	return '(unknown time)'
246
def smt_typ (typ):
	"""Return the SMT-LIB sort for syntax type typ, as a string.

	'Word' types become bit vectors of width typ.num, 'WordArray' types
	become arrays between the two bit-vector widths in typ.nums and
	'TokenWords' types become arrays from the token type to words of
	width typ.num. Any other type is looked up by name in
	smt_typ_builtins.
	"""
	kind = typ.kind
	if kind == 'Word':
		return '(_ BitVec %d)' % typ.num
	if kind == 'WordArray':
		return '(Array (_ BitVec %d) (_ BitVec %d))' % tuple (typ.nums)
	if kind == 'TokenWords':
		widths = (token_smt_typ.num, typ.num)
		return '(Array (_ BitVec %d) (_ BitVec %d))' % widths
	return smt_typ_builtins[typ.name]
256
# Word type used to represent values of the 'Token' builtin type.
token_smt_typ = syntax.word64T

# SMT sorts of the builtin types, keyed by type name. The '{MemSort}' and
# '{MemDomSort}' entries are placeholders, filled in with str.format when a
# message is sent to a solver (see Solver.send_inner).
smt_typ_builtins = {'Bool':'Bool', 'Mem':'{MemSort}', 'Dom':'{MemDomSort}',
	'Token': smt_typ (token_smt_typ)}

# Types never declared to the solver: Solver.add_var and Solver.add_def
# only reserve a name for variables of these types.
smt_typs_omitted = set ([builtinTs['HTD'], builtinTs['PMS']])

# Operator name -> SMT function name, starting from the syntax module's
# table.
smt_ops = dict (syntax.ops_to_smt)
# these additional smt ops aren't used as keywords in the syntax
more_smt_ops = {
	'TokenWordsAccess': 'select', 'TokenWordsUpdate': 'store'
}
smt_ops.update (more_smt_ops)
270
def smt_num (num, bits):
	"""Return integer num as an SMT-LIB bit-vector literal of width bits.

	Widths divisible by 4 use hexadecimal ('#x...'), all others binary
	('#b...'). The literal always has exactly the digits the width
	requires: short values are zero padded and digits beyond the width
	are dropped from the high end. A negative num is written as
	'(bvneg <literal of -num>)'.
	"""
	if num < 0:
		return '(bvneg %s)' % smt_num (- num, bits)
	if bits % 4 == 0:
		# floor division keeps digs an integer whatever division
		# semantics are in force.
		digs = bits // 4
		rep = '%x' % num
		prefix = '#x'
	else:
		digs = bits
		rep = '{x:b}'.format (x = num)
		prefix = '#b'
	rep = rep[-digs:]
	rep = ('0' * (digs - len(rep))) + rep
	assert len (rep) == digs
	return prefix + rep
286
def smt_num_t (num, typ):
	"""Return num as an SMT bit-vector literal of word type typ.

	typ must be a 'Word' type (asserted); its width is typ.num.
	"""
	assert typ.kind == 'Word', typ
	width = typ.num
	return smt_num (num, width)
290
def mk_smt_expr (smt_expr, typ):
	"""Wrap SMT text smt_expr as an 'SMTExpr' expression of type typ."""
	wrapped = Expr ('SMTExpr', typ, val = smt_expr)
	return wrapped
293
class EnvMiss (Exception):
	"""Raised by smt_expr when a variable is missing from the environment.

	name and typ identify the variable that was looked up.
	"""
	def __init__ (self, name, typ):
		(self.name, self.typ) = (name, typ)
298
# One-element list used as a flag cell. It is consulted by smt_expr when
# translating a 'MemDom' operator, which then emits 'true' in place of the
# real mem-dom term (the term is still reported to the solver object).
cheat_mem_doms = [True]

# NOTE(review): not referenced in the part of the file reviewed here
# (Solver objects keep their own self.tokens).
tokens = {}
302
def smt_expr (expr, env, solv):
	"""Translate syntax expression expr into SMT-LIB text.

	env maps (variable name, type) pairs to the SMT representation of
	each variable. solv is the Solver on whose behalf the translation is
	done; several cases call back into it to declare variables, record
	expressions whose values are wanted in models, or query it.

	Returns a string, or for memories that have been split a tuple
	('SplitMem', split, top, bot).

	Raises EnvMiss if a variable is not found in env. Unsupported
	expressions fail an assertion.
	"""
	if expr.is_op (['WordCast', 'WordCastSigned']):
		[v] = expr.vals
		assert v.typ.kind == 'Word' and expr.typ.kind == 'Word'
		ex = smt_expr (v, env, solv)
		if expr.typ == v.typ:
			return ex
		elif expr.typ.num < v.typ.num:
			# narrowing: keep the low bits.
			return '((_ extract %d 0) %s)' % (expr.typ.num - 1, ex)
		else:
			if expr.name == 'WordCast':
				return '((_ zero_extend %d) %s)' % (
					expr.typ.num - v.typ.num, ex)
			else:
				return '((_ sign_extend %d) %s)' % (
					expr.typ.num - v.typ.num, ex)
	elif expr.is_op (['ToFloatingPoint', 'ToFloatingPointSigned',
			'ToFloatingPointUnsigned', 'FloatingPointCast']):
		ks = [v.typ.kind for v in expr.vals]
		expected_ks = {'ToFloatingPoint': ['Word'],
			'ToFloatingPointSigned': ['Builtin', 'Word'],
			'ToFloatingPointUnsigned': ['Builtin', 'Word'],
			'FloatingPointCast': ['FloatingPoint']}
		expected_ks = expected_ks[expr.name]
		assert ks == expected_ks, (ks, expected_ks)
		# only the unsigned conversion has its own SMT operator.
		oname = 'to_fp'
		if expr.name == 'ToFloatingPointUnsigned':
			oname = 'to_fp_unsigned'
		op = '(_ %s %d %d)' % tuple ([oname] + list (expr.typ.nums))
		vs = [smt_expr (v, env, solv) for v in expr.vals]
		return '(%s %s)' % (op, ' '.join (vs))
	elif expr.is_op (['CountLeadingZeroes', 'WordReverse']):
		[v] = expr.vals
		assert expr.typ.kind == 'Word' and expr.typ == v.typ
		v = smt_expr (v, env, solv)
		oper = solv.get_smt_derived_oper (expr.name, expr.typ.num)
		return '(%s %s)' % (oper, v)
	elif expr.is_op ('CountTrailingZeroes'):
		# trailing zeroes == leading zeroes of the reversed word.
		[v] = expr.vals
		expr = syntax.mk_clz (syntax.mk_word_reverse (v))
		return smt_expr (expr, env, solv)
	elif expr.is_op (['PValid', 'PGlobalValid',
			'PWeakValid', 'PArrayValid']):
		if expr.name == 'PArrayValid':
			[htd, typ_expr, p, num] = expr.vals
			num = to_smt_expr (num, env, solv)
		else:
			[htd, typ_expr, p] = expr.vals
		assert typ_expr.kind == 'Type'
		typ = typ_expr.val
		if expr.name == 'PGlobalValid':
			typ = get_global_wrapper (typ)
		if expr.name == 'PArrayValid':
			typ = ('Array', typ, num)
		else:
			typ = ('Type', typ)
		assert htd.kind == 'Var'
		htd_s = env[(htd.name, htd.typ)]
		p_s = smt_expr (p, env, solv)
		var = solv.add_pvalids (htd_s, typ, p_s, expr.name)
		return var
	elif expr.is_op ('MemDom'):
		[p, dom] = [smt_expr (e, env, solv) for e in expr.vals]
		md = '(%s %s %s)' % (smt_ops[expr.name], p, dom)
		solv.note_mem_dom (p, dom, md)
		# cheat_mem_doms is a one-element flag cell; test its content,
		# not the (always true) list itself.
		if cheat_mem_doms[0]:
			return 'true'
		return md
	elif expr.is_op ('MemUpdate'):
		[m, p, v] = expr.vals
		assert v.typ.kind == 'Word'
		m_s = smt_expr (m, env, solv)
		p_s = smt_expr (p, env, solv)
		v_s = smt_expr (v, env, solv)
		return smt_expr_memupd (m_s, p_s, v_s, v.typ, solv)
	elif expr.is_op ('MemAcc'):
		[m, p] = expr.vals
		assert expr.typ.kind == 'Word'
		m_s = smt_expr (m, env, solv)
		p_s = smt_expr (p, env, solv)
		return smt_expr_memacc (m_s, p_s, expr.typ, solv)
	elif expr.is_op ('Equals') and expr.vals[0].typ == builtinTs['Mem']:
		(x, y) = [smt_expr (e, env, solv) for e in expr.vals]
		if x[0] == 'SplitMem' or y[0] == 'SplitMem':
			assert not 'mem equality involving split possible', (
				x, y, expr)
		sexp = '(mem-eq %s %s)' % (x, y)
		solv.note_model_expr (sexp, boolT)
		return sexp
	elif expr.is_op ('Equals') and expr.vals[0].typ == word32T:
		(x, y) = [smt_expr (e, env, solv) for e in expr.vals]
		sexp = '(word32-eq %s %s)' % (x, y)
		return sexp
	elif expr.is_op ('StackEqualsImplies'):
		[sp1, st1, sp2, st2] = [smt_expr (e, env, solv)
			for e in expr.vals]
		if sp1 == sp2 and st1 == st2:
			return 'true'
		assert st2[0] == 'SplitMem', (expr.vals, st2)
		[_, split2, top2, bot2] = st2
		if split2 != sp2:
			# textually different: have the solver confirm the
			# split point and stack pointer cannot be equal.
			res = solv.check_hyp_raw ('(= %s %s)' % (split2, sp2))
			assert res == 'unsat', (split2, sp2, expr.vals)
		eq = solv.get_stack_eq_implies (split2, top2, st1)
		return '(and (= %s %s) %s)' % (sp1, sp2, eq)
	elif expr.is_op ('ImpliesStackEquals'):
		[sp1, st1, sp2, st2] = expr.vals
		eq = solv.add_implies_stack_eq (sp1, st1, st2, env)
		sp1 = smt_expr (sp1, env, solv)
		sp2 = smt_expr (sp2, env, solv)
		return '(and (= %s %s) %s)' % (sp1, sp2, eq)
	elif expr.is_op ('IfThenElse'):
		(sw, x, y) = [smt_expr (e, env, solv) for e in expr.vals]
		return smt_ifthenelse (sw, x, y, solv)
	elif expr.is_op ('HTDUpdate'):
		var = solv.add_var ('updated_htd', expr.typ)
		return var
	elif expr.kind == 'Op':
		# generic operator: translate the arguments and look the
		# operator name up in smt_ops.
		vals = [smt_expr (e, env, solv) for e in expr.vals]
		if vals:
			sexp = '(%s %s)' % (smt_ops[expr.name], ' '.join(vals))
		else:
			sexp = smt_ops[expr.name]
		maybe_note_model_expr (sexp, expr.typ, expr.vals, solv)
		return sexp
	elif expr.kind == 'Num':
		return smt_num_t (expr.val, expr.typ)
	elif expr.kind == 'Var':
		if (expr.name, expr.typ) not in env:
			trace ('Env miss for %s in smt_expr' % expr.name)
			trace ('Environment is %s' % env)
			raise EnvMiss (expr.name, expr.typ)
		val = env[(expr.name, expr.typ)]
		assert val[0] == 'SplitMem' or type(val) == str
		return val
	elif expr.kind == 'Invent':
		var = solv.add_var ('invented', expr.typ)
		return var
	elif expr.kind == 'SMTExpr':
		# already SMT text.
		return expr.val
	elif expr.kind == 'Token':
		return solv.get_token (expr.name)
	else:
		assert not 'handled expr', expr
447
def smt_expr_memacc (m, p, typ, solv):
	"""Return SMT text reading a word of type typ from memory m at p.

	m and p are SMT representations. If m is a ('SplitMem', split, top,
	bot) tuple, the read selects between the two halves: addresses at or
	above split read from top, the rest from bot. Otherwise one of the
	load-word8/32/64 functions is used; other widths fail an assertion.
	The pointer and the load are recorded with solv.note_model_expr.
	"""
	if m[0] == 'SplitMem':
		p = solv.cache_large_expr (p, 'memacc_pointer', syntax.word32T)
		(_, split, top, bot) = m
		from_top = smt_expr_memacc (top, p, typ, solv)
		from_bot = smt_expr_memacc (bot, p, typ, solv)
		in_top = '(bvule %s %s)' % (split, p)
		return '(ite %s %s %s)' % (in_top, from_top, from_bot)
	width = typ.num
	if width in [8, 32, 64]:
		sexp = '(load-word%d %s %s)' % (width, m, p)
	else:
		assert not 'word load type supported', typ
	solv.note_model_expr (p, syntax.word32T)
	solv.note_model_expr (sexp, typ)
	return sexp
462
def smt_expr_memupd (m, p, v, typ, solv):
	"""Return the SMT representation of memory m with v stored at p.

	m, p and v are SMT representations and typ is the word type of v.
	For a ('SplitMem', split, top, bot) memory both halves are updated
	conditionally: the top half changes when p is at or above split, the
	bottom half otherwise, and a new 'SplitMem' tuple is returned.
	Otherwise the result is a store-word8/32/64 term; other widths fail
	an assertion. Related loads and pointers are recorded with
	solv.note_model_expr.
	"""
	if m[0] == 'SplitMem':
		p = solv.cache_large_expr (p, 'memupd_pointer', syntax.word32T)
		v = solv.cache_large_expr (v, 'memupd_val', typ)
		(_, split, top, bot) = m
		memT = syntax.builtinTs['Mem']
		top = solv.cache_large_expr (top, 'split_mem_top', memT)
		new_top = smt_expr_memupd (top, p, v, typ, solv)
		bot = solv.cache_large_expr (bot, 'split_mem_bot', memT)
		new_bot = smt_expr_memupd (bot, p, v, typ, solv)
		in_top = '(bvule %s %s)' % (split, p)
		top = '(ite %s %s %s)' % (in_top, new_top, top)
		bot = '(ite %s %s %s)' % (in_top, bot, new_bot)
		return ('SplitMem', split, top, bot)
	width = typ.num
	if width == 8:
		p = solv.cache_large_expr (p, 'memupd_pointer', syntax.word32T)
		# NOTE(review): this mask clears only bit 1 of the pointer; a
		# 4-byte alignment mask would be #xfffffffc. Left as is since
		# changing it alters the recorded expressions - confirm intent.
		p_align = '(bvand %s #xfffffffd)' % p
		solv.note_model_expr (p_align, syntax.word32T)
		word_load = '(load-word32 %s %s)' % (m, p_align)
		solv.note_model_expr (word_load, syntax.word32T)
		return '(store-word8 %s %s %s)' % (m, p, v)
	elif width in [32, 64]:
		prev_load = '(load-word%d %s %s)' % (width, m, p)
		solv.note_model_expr (prev_load, typ)
		solv.note_model_expr (p, syntax.word32T)
		return '(store-word%d %s %s %s)' % (width, m, p, v)
	else:
		assert not 'MemUpdate word width supported', typ
490
def smt_ifthenelse (sw, x, y, solv):
	"""Return the SMT representation of 'if sw then x else y'.

	When neither branch is a ('SplitMem', split, top, bot) tuple the
	result is a plain '(ite ...)' string. Otherwise a 'SplitMem' tuple is
	built component by component; a branch that is not split is treated
	as split at address zero with itself as both halves. solv is unused.
	"""
	if x[0] != 'SplitMem' and y[0] != 'SplitMem':
		return '(ite %s %s %s)' % (sw, x, y)

	def components (mem):
		# (split, top, bot) of a memory, split or not.
		if mem[0] != 'SplitMem':
			return ('#x00000000', mem, mem)
		(_, split, top, bot) = mem
		return (split, top, bot)

	(x_split, x_top, x_bot) = components (x)
	(y_split, y_top, y_bot) = components (y)
	if x_split == y_split:
		split = x_split
	else:
		split = '(ite %s %s %s)' % (sw, x_split, y_split)
	top = '(ite %s %s %s)' % (sw, x_top, y_top)
	bot = '(ite %s %s %s)' % (sw, x_bot, y_bot)
	return ('SplitMem', split, top, bot)
510
def to_smt_expr (expr, env, solv):
	"""Translate expr to SMT and wrap the result as an expression again.

	Expressions of the 'RelWrapper' builtin type keep their operator and
	have each argument converted recursively. Anything else is converted
	with smt_expr and wrapped with mk_smt_expr at the original type.
	"""
	if expr.typ == builtinTs['RelWrapper']:
		converted = []
		for v in expr.vals:
			converted.append (to_smt_expr (v, env, solv))
		return syntax.adjust_op_vals (expr, converted)
	return mk_smt_expr (smt_expr (expr, env, solv), expr.typ)
517
def typ_representable (typ):
	"""True for word types and for the Bool and Token builtin types."""
	is_word = typ.kind == 'Word'
	return (is_word or typ == builtinTs['Bool']
		or typ == builtinTs['Token'])
521
def maybe_note_model_expr (sexpr, typ, subexprs, solv):
	"""Record sexpr with the solver when its value is representable but
	the value of at least one of its subexpressions is not.
	e.g. (= x y) is recorded where the type of x/y is an SMT array.

	typ is the type of sexpr; subexprs are its argument expressions."""
	if not typ_representable (typ):
		return
	for sub in subexprs:
		if not typ_representable (sub.typ):
			break
	else:
		# every argument is representable: nothing to record.
		return
	assert solv, (sexpr, typ)
	solv.note_model_expr (sexpr, typ)
532
def split_hyp_sexpr (hyp, accum):
	"""Split a parsed hypothesis into conjuncts, appended to accum.

	hyp is an s-expression in tuple form. The rules applied, recursively:
	  (and a b ...)   -> a, b, ...
	  (not (=> p q))  -> p, (not q)
	  (not (or a b))  -> (not a), (not b)
	  (not (not a))   -> a
	  (=> a false)    -> (not a)
	  (= true a)      -> a          (either argument order)
	  (= false a)     -> (not a)    (either argument order)
	Anything else is appended unchanged. Returns accum.
	"""
	head = hyp[0]
	if head == 'and':
		for conjunct in hyp[1:]:
			split_hyp_sexpr (conjunct, accum)
		return accum
	if head == 'not':
		body = hyp[1]
		inner = body[0]
		if inner == '=>':
			(_, premise, conclusion) = body
			split_hyp_sexpr (premise, accum)
			split_hyp_sexpr (('not', conclusion), accum)
			return accum
		if inner == 'or':
			for disjunct in body[1:]:
				split_hyp_sexpr (('not', disjunct), accum)
			return accum
		if inner == 'not':
			split_hyp_sexpr (body[1], accum)
			return accum
	if hyp[:1] + hyp[2:] == ('=>', 'false'):
		split_hyp_sexpr (('not', hyp[1]), accum)
	elif hyp[:1] == ('=', ) and ('true' in hyp or 'false' in hyp):
		(_, const, other) = hyp
		if other in ['true', 'false']:
			# put the constant first.
			(const, other) = (other, const)
		if const == 'true':
			split_hyp_sexpr (other, accum)
		else:
			split_hyp_sexpr (('not', other), accum)
	else:
		accum.append (hyp)
	return accum
559
def split_hyp (hyp):
	"""Split hypothesis text hyp into a list of conjunct texts.

	Only hypotheses that start with a conjunction or with a negated
	implication, disjunction or negation are parsed and split; any
	other hypothesis is returned as a one-element list.
	"""
	splittable = ('(and ', '(not (=> ', '(not (or ', '(not (not ')
	if not hyp.startswith (splittable):
		return [hyp]
	parts = split_hyp_sexpr (parse_s_expression (hyp), [])
	return [flat_s_expression (h) for h in parts]
568
# SMT-LIB definitions of the memory helper functions for a memory modelled
# one byte per address. Solver.preamble uses this list when the solver's
# mem_mode is '8'. Word loads and stores are assembled from four byte
# accesses, least significant byte at the lowest address. '{MemSort}' and
# '{MemDomSort}' are str.format placeholders for the memory sorts.
mem_word8_preamble = [
'''(define-fun load-word32 ((m {MemSort}) (p (_ BitVec 32)))
	(_ BitVec 32)
(concat (concat (select m (bvadd p #x00000003)) (select m (bvadd p #x00000002)))
  (concat (select m (bvadd p #x00000001)) (select m p))))
''',
'''(define-fun load-word64 ((m {MemSort}) (p (_ BitVec 32)))
	(_ BitVec 64)
(bvor ((_ zero_extend 32) (load-word32 m p))
	(bvshl ((_ zero_extend 32)
		(load-word32 m (bvadd p #x00000004))) #x0000000000000020)))''',
'''(define-fun store-word32 ((m {MemSort}) (p (_ BitVec 32))
	(v (_ BitVec 32))) {MemSort}
(store (store (store (store m p ((_ extract 7 0) v))
	(bvadd p #x00000001) ((_ extract 15 8) v))
	(bvadd p #x00000002) ((_ extract 23 16) v))
	(bvadd p #x00000003) ((_ extract 31 24) v))
) ''',
'''(define-fun store-word64 ((m {MemSort}) (p (_ BitVec 32)) (v (_ BitVec 64)))
        {MemSort}
(store-word32 (store-word32 m p ((_ extract 31 0) v))
	(bvadd p #x00000004) ((_ extract 63 32) v)))''',
'''(define-fun load-word8 ((m {MemSort}) (p (_ BitVec 32))) (_ BitVec 8)
(select m p))''',
'''(define-fun store-word8 ((m {MemSort}) (p (_ BitVec 32)) (v (_ BitVec 8)))
	{MemSort}
(store m p v))''',
'''(define-fun mem-dom ((p (_ BitVec 32)) (d {MemDomSort})) Bool
(not (= (select d p) #b0)))''',
'''(define-fun mem-eq ((x {MemSort}) (y {MemSort})) Bool (= x y))''',
'''(define-fun word32-eq ((x (_ BitVec 32)) (y (_ BitVec 32)))
    Bool (= x y))''',
'''(define-fun word2-xor-scramble ((a (_ BitVec 2)) (x (_ BitVec 2))
   (b (_ BitVec 2)) (c (_ BitVec 2)) (y (_ BitVec 2)) (d (_ BitVec 2))) Bool
(bvult (bvadd (bvxor a x) b) (bvadd (bvxor c y) d)))''',
'''(declare-fun unspecified-precond () Bool)''',
]
606
# SMT-LIB definitions of the memory helper functions for a memory modelled
# one 32-bit word per array element, indexed by the top 30 bits of the
# address. Solver.preamble uses this list unless the solver's mem_mode is
# '8'. Byte accesses read or rewrite the containing word, shifting by the
# low two address bits. '{MemSort}' and '{MemDomSort}' are str.format
# placeholders for the memory sorts.
mem_word32_preamble = [
'''(define-fun load-word32 ((m {MemSort}) (p (_ BitVec 32)))
	(_ BitVec 32)
(select m ((_ extract 31 2) p)))''',
'''(define-fun store-word32 ((m {MemSort}) (p (_ BitVec 32)) (v (_ BitVec 32)))
	{MemSort}
(store m ((_ extract 31 2) p) v))''',
'''(define-fun load-word64 ((m {MemSort}) (p (_ BitVec 32)))
	(_ BitVec 64)
(bvor ((_ zero_extend 32) (load-word32 m p))
	(bvshl ((_ zero_extend 32)
		(load-word32 m (bvadd p #x00000004))) #x0000000000000020)))''',
'''(define-fun store-word64 ((m {MemSort}) (p (_ BitVec 32)) (v (_ BitVec 64)))
        {MemSort}
(store-word32 (store-word32 m p ((_ extract 31 0) v))
	(bvadd p #x00000004) ((_ extract 63 32) v)))''',
'''(define-fun word8-shift ((p (_ BitVec 32))) (_ BitVec 32)
(bvshl ((_ zero_extend 30) ((_ extract 1 0) p)) #x00000003))''',
'''(define-fun word8-get ((p (_ BitVec 32)) (x (_ BitVec 32))) (_ BitVec 8)
((_ extract 7 0) (bvlshr x (word8-shift p))))''',
'''(define-fun load-word8 ((m {MemSort}) (p (_ BitVec 32))) (_ BitVec 8)
(word8-get p (load-word32 m p)))''',
'''(define-fun word8-put ((p (_ BitVec 32)) (x (_ BitVec 32)) (y (_ BitVec 8)))
  (_ BitVec 32) (bvor (bvshl ((_ zero_extend 24) y) (word8-shift p))
	(bvand x (bvnot (bvshl #x000000FF (word8-shift p))))))''',
'''(define-fun store-word8 ((m {MemSort}) (p (_ BitVec 32)) (v (_ BitVec 8)))
	{MemSort}
(store-word32 m p (word8-put p (load-word32 m p) v)))''',
'''(define-fun mem-dom ((p (_ BitVec 32)) (d {MemDomSort})) Bool
(not (= (select d p) #b0)))''',
'''(define-fun mem-eq ((x {MemSort}) (y {MemSort})) Bool (= x y))''',
'''(define-fun word32-eq ((x (_ BitVec 32)) (y (_ BitVec 32)))
    Bool (= x y))''',
'''(define-fun word2-xor-scramble ((a (_ BitVec 2)) (x (_ BitVec 2))
   (b (_ BitVec 2)) (c (_ BitVec 2)) (y (_ BitVec 2)) (d (_ BitVec 2))) Bool
(bvult (bvadd (bvxor a x) b) (bvadd (bvxor c y) d)))''',
'''(declare-fun unspecified-precond () Bool)'''
]
645
# Values substituted for the '{MemSort}' and '{MemDomSort}' placeholders.
# Solver.send_inner formats every message with the word32 table.
word32_smt_convs = {'MemSort': '(Array (_ BitVec 30) (_ BitVec 32))',
	'MemDomSort': '(Array (_ BitVec 32) (_ BitVec 1))'}
# NOTE(review): the byte-addressed counterpart; its use is not visible in
# the part of the file reviewed here.
word8_smt_convs = {'MemSort': '(Array (_ BitVec 32) (_ BitVec 8))',
	'MemDomSort': '(Array (_ BitVec 32) (_ BitVec 1))'}
650
def preexec (timeout):
	"""Return a function suitable as the preexec_fn of subprocess.Popen.

	The returned function starts a new session and, unless timeout is
	None, sets both the soft and hard CPU time limits to timeout.
	"""
	def child_setup ():
		# a session of its own lets the whole process group of the
		# solver be cleaned up, which matters when several solvers
		# run in parallel.
		os.setsid ()
		if timeout is None:
			return
		resource.setrlimit (resource.RLIMIT_CPU, (timeout, timeout))
	return child_setup
661
class ConversationProblem (Exception):
	"""Raised when the exchange with an online solver goes wrong.

	prompt is what was sent to the solver, response what came back (or
	the string 'IOError' if the pipe failed).
	"""
	def __init__ (self, prompt, response):
		(self.prompt, self.response) = (prompt, response)
666
def get_s_expression (stream, prompt):
	"""Read one solver response from stream (see get_s_expression_inner).

	prompt is the message being answered; it is only used for error
	reporting. An IOError while reading is converted into a
	ConversationProblem.
	"""
	try:
		result = get_s_expression_inner (stream, prompt)
	except IOError:
		raise ConversationProblem (prompt, 'IOError')
	return result
672
def get_s_expression_inner (stdout, prompt):
	"""Retrieve responses from a solver until parens match.

	A first line that does not start with '(' must be a single word,
	which is returned as is. Otherwise lines are read until the closing
	parentheses seen are no fewer than the opening ones, and all lines
	read are passed to parse_s_expressions. ConversationProblem is
	raised for a malformed one-line response or when three blank lines
	arrive in a row while the expression is still open."""
	first = stdout.readline ().strip ()
	if not first.startswith ('('):
		words = first.split ()
		if len (words) == 1:
			return words[0]
		raise ConversationProblem (prompt, first)
	responses = [first]
	opened = first.count ('(')
	closed = first.count (')')
	blank_run = 0
	while closed < opened:
		line = stdout.readline ().strip ()
		responses.append (line)
		opened += line.count ('(')
		closed += line.count (')')
		if line != '':
			blank_run = 0
			continue
		blank_run += 1
		if blank_run >= 3:
			raise ConversationProblem (prompt, responses)
	return parse_s_expressions (responses)
696
class SolverFailure (Exception):
	"""Exception carrying a message msg describing a solver failure."""
	def __init__ (self, msg):
		self.msg = msg

	def __str__ (self):
		shown = repr (self.msg)
		return 'SolverFailure (%s)' % shown
703
704class Solver:
	def __init__ (self, produce_unsat_cores = False):
		# Set up an empty solver state. No solver process is started
		# here by this method itself; see startup_solver.
		#
		# produce_unsat_cores: if true, the preamble asks the solver
		# to produce unsat cores.

		# (msg, is_model) pairs recorded by send; startup_solver
		# resends the messages to a freshly started solver.
		self.replayable = []
		self.unsat_cores = produce_unsat_cores
		# the subprocess.Popen object of the online solver, or None.
		self.online_solver = None
		self.parallel_solvers = {}
		self.parallel_model_states = {}

		# naming state used by smt_name.
		self.names_used = {}
		self.names_used_order = []
		self.external_names = {}
		self.name_ext = ''
		self.pvalids = {}
		self.ptrs = {}
		self.cached_exprs = {}
		# add_def stores the parsed definition of each defined name.
		self.defs = {}
		self.doms = set ()
		# names of declared/defined variables of representable type.
		self.model_vars = set ()
		self.model_exprs = {}
		self.arbitrary_vars = {}
		self.stack_eqs = {}
		# add_var records the memory name given for 'Mem' variables.
		self.mem_naming = {}
		self.tokens = {}
		self.smt_derived_ops = {}

		self.num_hyps = 0
		self.last_model_acc_hyps = (None, None)

		self.pvalid_doms = None
		self.assertions = []

		# solver configuration loaded at module import.
		self.fast_solver = fast_solver
		self.slow_solver = slow_solver
		self.strategy = strategy
		self.model_strategy = model_strategy

		# must run after the state above has been initialised.
		self.add_rodata_def ()

		last_solver[0] = self
743
744	def preamble (self, solver_impl):
745		preamble = []
746		if solver_impl.fast:
747			preamble += ['(set-option :print-success true)']
748		preamble += [ '(set-option :produce-models true)',
749			'(set-logic QF_AUFBV)', ]
750		if self.unsat_cores:
751			preamble += ['(set-option :produce-unsat-cores true)']
752
753		if solver_impl.mem_mode == '8':
754			preamble.extend (mem_word8_preamble)
755		else:
756			preamble.extend (mem_word32_preamble)
757		return preamble
758
759	def startup_solver (self, use_this_solver = None):
760		if self not in active_solvers:
761			active_solvers.append (self)
762			while len (active_solvers) > max_active_solvers[0]:
763				solv = active_solvers.pop (0)
764				solv.close ('active solver limit')
765
766		if use_this_solver:
767			solver = use_this_solver
768		else:
769			solver = self.fast_solver
770		devnull = open (os.devnull, 'w')
771		self.online_solver = subprocess.Popen (solver.args,
772			stdin = subprocess.PIPE, stdout = subprocess.PIPE,
773			stderr = devnull, preexec_fn = preexec (solver.timeout))
774		devnull.close ()
775
776		for msg in self.preamble (solver):
777			self.send (msg, replay=False)
778		for (msg, _) in self.replayable:
779			self.send (msg, replay=False)
780
781	def close (self, reason = '?'):
782		self.close_parallel_solvers (reason = 'self.close (%s)'
783			% reason)
784		self.close_online_solver ()
785
786	def close_online_solver (self):
787		if self.online_solver:
788			self.online_solver.stdin.close()
789			self.online_solver.stdout.close()
790			self.online_solver = None
791
792	def __del__ (self):
793		self.close ('__del__')
794
795	def smt_name (self, name, kind = ('Var', None),
796			ignore_external_names = False):
797		name = name.replace("'", "_").replace("#", "_").replace('"', "_")
798		if not ignore_external_names:
799			name = fresh_name (name, self.external_names)
800		name = fresh_name (name, self.names_used, kind)
801		self.names_used_order.append (name)
802		return name
803
804	def write (self, msg):
805		self.online_solver.stdin.write (msg + '\n')
806		self.online_solver.stdin.flush()
807
808	def send_inner (self, msg, replay = True, is_model = True):
809		if self.online_solver == None:
810			self.startup_solver ()
811
812		msg = msg.format (** word32_smt_convs)
813		try:
814			self.write (msg)
815			response = self.online_solver.stdout.readline().strip()
816		except IOError, e:
817			raise ConversationProblem (msg, 'IOError')
818		if response != 'success':
819			raise ConversationProblem (msg, response)
820
821	def solver_loop (self, attempt):
822		err = None
823		for i in range (5):
824			if self.online_solver == None:
825				self.startup_solver ()
826			try:
827				return attempt ()
828			except ConversationProblem, e:
829				trace ('SMT conversation problem (attempt %d)'
830					% (i + 1))
831				trace ('I sent %s' % repr (e.prompt))
832				trace ('I got %s' % repr (e.response))
833				trace ('restarting solver')
834				self.online_solver = None
835				err = (e.prompt, e.response)
836		trace ('Repeated SMT failure, giving up.')
837		raise ConversationProblem (err[0], err[1])
838
839	def send (self, msg, replay = True, is_model = True):
840		self.solver_loop (lambda: self.send_inner (msg,
841			replay = replay, is_model = is_model))
842		if replay:
843			self.replayable.append ((msg, is_model))
844
845	def get_s_expression (self, prompt):
846		return get_s_expression (self.online_solver.stdout, prompt)
847
848	def prompt_s_expression_inner (self, prompt):
849		try:
850			self.write (prompt)
851			return self.get_s_expression (prompt)
852		except IOError, e:
853			raise ConversationProblem (prompt, 'IOError')
854
855	def prompt_s_expression (self, prompt):
856		return self.solver_loop (lambda:
857			self.prompt_s_expression_inner (prompt))
858
859	def hyps_sat_raw_inner (self, hyps, model, unsat_core,
860			recursion = False):
861		self.send_inner ('(push 1)', replay = False)
862		for hyp in hyps:
863			self.send_inner ('(assert %s)' % hyp, replay = False,
864				is_model = False)
865		response = self.prompt_s_expression_inner ('(check-sat)')
866		if response not in set (['sat', 'unknown', 'unsat', '']):
867			raise ConversationProblem ('(check-sat)', response)
868
869		all_ok = True
870		m = {}
871		ucs = []
872		if response == 'sat' and model:
873			all_ok = self.fetch_model (m)
874		if response == 'unsat' and unsat_core:
875			ucs = self.get_unsat_core ()
876			all_ok = ucs != None
877
878		self.send_inner ('(pop 1)', replay = False)
879
880		return (response, m, ucs, all_ok)
881
882	def add_var (self, name, typ, kind = 'Var',
883			mem_name = None,
884			ignore_external_names = False):
885		if typ in smt_typs_omitted:
886			# skipped. not supported by all solvers
887			name = self.smt_name (name, ('Ghost', typ),
888				ignore_external_names = ignore_external_names)
889			return name
890		name = self.smt_name (name, kind = (kind, typ),
891			ignore_external_names = ignore_external_names)
892		self.send ('(declare-fun %s () %s)' % (name, smt_typ(typ)))
893		if typ_representable (typ) and kind != 'Aux':
894			self.model_vars.add (name)
895		if typ == builtinTs['Mem'] and mem_name != None:
896			if type (mem_name) == str:
897				self.mem_naming[name] = mem_name
898			else:
899				(nm, prev) = mem_name
900				if prev[0] == 'SplitMem':
901					prev = 'SplitMem'
902				prev = parse_s_expression (prev)
903				self.mem_naming[name] = (nm, prev)
904		return name
905
906	def add_var_restr (self, name, typ, mem_name = None):
907		name = self.add_var (name, typ, mem_name = mem_name)
908		return name
909
910	def add_def (self, name, val, env, ignore_external_names = False):
911		kind = 'Var'
912		if val.typ in smt_typs_omitted:
913			kind = 'Ghost'
914		smt = smt_expr (val, env, self)
915		if smt[0] == 'SplitMem':
916			(_, split, top, bot) = smt
917			def add (nm, typ, smt):
918				val = mk_smt_expr (smt, typ)
919				return self.add_def (name + '_' + nm, val, {},
920					ignore_external_names = ignore_external_names)
921			split = add ('split', syntax.word32T, split)
922			top = add ('top', val.typ, top)
923			bot = add ('bot', val.typ, bot)
924			return ('SplitMem', split, top, bot)
925
926		name = self.smt_name (name, kind = (kind, val.typ),
927			ignore_external_names = ignore_external_names)
928		if kind == 'Ghost':
929			# skipped. not supported by all solvers
930			return name
931		if val.kind == 'Var':
932			trace ('WARNING: redef of var %r to name %s' % (val, name))
933
934		typ = smt_typ (val.typ)
935		self.send ('(define-fun %s () %s %s)' % (name, typ, smt))
936
937		self.defs[name] = parse_s_expression (smt)
938		if typ_representable (val.typ):
939			self.model_vars.add (name)
940
941		return name
942
943	def add_rodata_def (self):
944		ro_name = self.smt_name ('rodata', kind = 'Fun')
945		imp_ro_name = self.smt_name ('implies-rodata', kind = 'Fun')
946		assert ro_name == 'rodata', repr (ro_name)
947		assert imp_ro_name == 'implies-rodata', repr (imp_ro_name)
948		[rodata_data, rodata_ranges, rodata_ptrs] = rodata
949		if not rodata_ptrs:
950			assert not rodata_data
951			ro_def = 'true'
952			imp_ro_def = 'true'
953		else:
954			ro_witness = self.add_var ('rodata-witness', word32T)
955			ro_witness_val = self.add_var ('rodata-witness-val', word32T)
956			assert ro_witness == 'rodata-witness'
957			assert ro_witness_val == 'rodata-witness-val'
958			eq_vs = [(smt_num (p, 32), smt_num (v, 32))
959				for (p, v) in rodata_data.iteritems ()]
960			eq_vs.append ((ro_witness, ro_witness_val))
961			eqs = ['(= (load-word32 m %s) %s)' % v for v in eq_vs]
962			ro_def = '(and %s)' % ' \n  '.join (eqs)
963			ro_ineqs = ['(and (bvule %s %s) (bvule %s %s))'
964				% (smt_num (start, 32), ro_witness,
965					ro_witness, smt_num (end, 32))
966				for (start, end) in rodata_ranges]
967			assns = ['(or %s)' % ' '.join (ro_ineqs),
968				'(= (bvand rodata-witness #x00000003) #x00000000)']
969			for assn in assns:
970				self.assert_fact_smt (assn)
971			imp_ro_def = eqs[-1]
972		self.send ('(define-fun rodata ((m %s)) Bool %s)' % (
973			smt_typ (builtinTs['Mem']), ro_def))
974		self.send ('(define-fun implies-rodata ((m %s)) Bool %s)' % (
975			smt_typ (builtinTs['Mem']), imp_ro_def))
976
977	def get_eq_rodata_witness (self, v):
978		# depends on assertion above, should probably fix this
979		ro_witness = mk_smt_expr ('rodata-witness', word32T)
980		return syntax.mk_eq (ro_witness, v)
981
982	def check_hyp_raw (self, hyp, model = None, force_solv = False,
983			hyp_name = None):
984		return self.hyps_sat_raw ([('(not %s)' % hyp, None)],
985			model = model, unsat_core = None,
986			force_solv = force_solv, hyps_name = hyp_name)
987
988	def next_hyp (self, (hyp, tag), hyp_dict):
989		self.num_hyps += 1
990		name = 'hyp%d' % self.num_hyps
991		hyp_dict[name] = tag
992		return '(! %s :named %s)' % (hyp, name)
993
994	def hyps_sat_raw (self, hyps, model = None, unsat_core = None,
995			force_solv = False, recursion = False,
996			slow_solver = None, hyps_name = None):
997		assert self.unsat_cores or unsat_core == None
998
999		hyp_dict = {}
1000		raw_hyps = [(hyp2, tag) for (hyp, tag) in hyps
1001			for hyp2 in split_hyp (hyp)]
1002		last_hyps[0] = list (raw_hyps)
1003		hyps = [self.next_hyp (h, hyp_dict) for h in raw_hyps]
1004		succ = False
1005		solvs_used = []
1006		if hyps_name == None:
1007			hyps_name = 'group of %d hyps' % len (hyps)
1008		trace ('testing %s:' % hyps_name)
1009		if recursion:
1010			trace ('  (recursion)')
1011		else:
1012			for (hyp, _) in raw_hyps:
1013				trace ('  ' + hyp)
1014
1015		if force_solv != 'Slow':
1016			solvs_used.append (self.fast_solver.name)
1017			l = lambda: self.hyps_sat_raw_inner (hyps,
1018                                        model != None, unsat_core != None,
1019					recursion = recursion)
1020			try:
1021				(response, m, ucs, succ) = self.solver_loop (l)
1022			except ConversationProblem, e:
1023				response = 'ConversationProblem'
1024
1025		if succ and m and not recursion:
1026			succ = self.check_model ([h for (h, _) in raw_hyps], m)
1027
1028		if slow_solver == None:
1029			slow_solver = self.slow_solver
1030		if ((not succ or response not in ['sat', 'unsat'])
1031				and slow_solver and force_solv != 'Fast'):
1032			if solvs_used:
1033				trace ('failed to get result from %s'
1034					% solvs_used[0])
1035			trace ('running %s' % slow_solver.name)
1036			self.close_online_solver ()
1037			solvs_used.append (slow_solver.name)
1038			response = self.use_slow_solver (raw_hyps,
1039				model = model, unsat_core = unsat_core,
1040				use_this_solver = slow_solver)
1041		elif not succ:
1042			pass
1043		elif m:
1044			model.clear ()
1045			model.update (m)
1046		elif ucs:
1047			unsat_core.extend (self.get_unsat_core_tags (ucs,
1048				hyp_dict))
1049
1050		if response == 'sat':
1051			if not recursion:
1052				last_satisfiable_hyps[0] = list (raw_hyps)
1053			if model and not recursion:
1054				assert self.check_model ([h for (h, _) in raw_hyps],
1055					model)
1056		elif response == 'unsat':
1057			fact = '(not (and %s))' % ' '.join ([h
1058				for (h, _) in raw_hyps])
1059			# sending this fact (and not its core-deps) might
1060			# lead to inaccurate cores in the future
1061			if not self.unsat_cores:
1062				self.send ('(assert %s)' % fact)
1063		else:
1064			# couldn't get a useful response from either solver.
1065			trace ('Solvers %s failed to resolve sat/unsat'
1066				% solvs_used)
1067			trace ('last solver result %r' % response)
1068			raise SolverFailure (response)
1069		return response
1070
1071	def get_unsat_core_tags (self, fact_names, hyps):
1072		names = set (fact_names)
1073		trace ('uc names: %s' % names)
1074		core = [hyps[name] for name in names
1075			if name.startswith ('hyp')]
1076		for s in fact_names:
1077			if s.startswith ('assert'):
1078				n = int (s[6:])
1079				core.append (self.assertions[n][1])
1080		trace ('uc tags: %s' % core)
1081		return core
1082
1083	def write_solv_script (self, f, input_msgs, solver = slow_solver,
1084			only_if_is_model = False):
1085		if solver.mem_mode == '8':
1086			smt_convs = word8_smt_convs
1087		else:
1088			smt_convs = word32_smt_convs
1089		for msg in self.preamble (solver):
1090			msg = msg.format (** smt_convs)
1091			f.write (msg + '\n')
1092		for (msg, is_model) in self.replayable:
1093			if only_if_is_model and not is_model:
1094				continue
1095			msg = msg.format (** smt_convs)
1096			f.write (msg + '\n')
1097
1098		for msg in input_msgs:
1099			msg = msg.format (** smt_convs)
1100			f.write (msg + '\n')
1101
1102		f.flush ()
1103
1104	def exec_slow_solver (self, input_msgs, timeout = None,
1105			use_this_solver = None):
1106		solver = self.slow_solver
1107		if use_this_solver:
1108			solver = use_this_solver
1109		if not solver:
1110			return 'no-slow-solver'
1111
1112		(fd, name) = tempfile.mkstemp (suffix='.txt',
1113			prefix='graph-refine-problem-')
1114		tmpfile_write = open (name, 'w')
1115		self.write_solv_script (tmpfile_write, input_msgs,
1116			solver = solver)
1117		tmpfile_write.close ()
1118
1119		proc = subprocess.Popen (solver.args,
1120			stdin = fd, stdout = subprocess.PIPE,
1121			preexec_fn = preexec (timeout))
1122		os.close (fd)
1123		os.unlink (name)
1124
1125		return (proc, proc.stdout)
1126
1127	def use_slow_solver (self, hyps, model = None, unsat_core = None,
1128			use_this_solver = None):
1129		start = time.time ()
1130
1131		cmds = ['(assert %s)' % hyp for (hyp, _) in hyps
1132			] + ['(check-sat)']
1133
1134		if model != None:
1135			cmds.append (self.fetch_model_request ())
1136
1137		if use_this_solver:
1138			solver = use_this_solver
1139		else:
1140			solver = self.slow_solver
1141
1142		(proc, output) = self.exec_slow_solver (cmds,
1143			timeout = solver.timeout, use_this_solver = solver)
1144
1145		response = output.readline ().strip ()
1146		if model != None and response == 'sat':
1147			assert self.fetch_model_response (model,
1148				stream = output)
1149		if unsat_core != None and response == 'unsat':
1150			trace ('WARNING no unsat core from %s' % solver.name)
1151			unsat_core.extend ([tag for (_, tag) in hyps])
1152
1153		output.close ()
1154
1155		if response not in ['sat', 'unsat']:
1156			trace ('SMT conversation problem after (check-sat)')
1157
1158		end = time.time ()
1159		trace ('Got %r from %s.' % (response, solver.name))
1160		trace ('  after %s' % run_time (end - start, proc))
1161		# adjust to save difficult problems
1162		cutoff_time = save_solv_example_time[0]
1163		if cutoff_time != -1 and end - start > cutoff_time:
1164			save_solv_example (self, cmds,
1165				comments = ['reference time %s seconds' % (end - start)])
1166
1167		if model:
1168			assert self.check_model ([h for (h, _) in hyps], model)
1169
1170		return response
1171
1172	def add_parallel_solver (self, k, hyps, model = None,
1173			use_this_solver = None):
1174		cmds = ['(assert %s)' % hyp for hyp in hyps] + ['(check-sat)']
1175
1176		if model != None:
1177			cmds.append (self.fetch_model_request ())
1178
1179		trace ('  --> new parallel solver %s' % str (k))
1180
1181		if k in self.parallel_solvers:
1182			raise IndexError ('duplicate parallel solver ID', k)
1183		solver = self.slow_solver
1184		if use_this_solver:
1185			solver = use_this_solver
1186		(proc, output) = self.exec_slow_solver (cmds,
1187			timeout = solver.timeout, use_this_solver = solver)
1188		self.parallel_solvers[k] = (hyps, proc, output, solver, model)
1189
	def wait_parallel_solver_step (self):
		"""wait for one parallel solver to produce output and
		process its result.

		returns (k, hyps, response) when a solver has given a final
		answer, or None when the solver was dropped or a follow-up
		'ModelRepair' solver was started instead. on
		KeyboardInterrupt all parallel solvers are closed and the
		interrupt is re-raised."""
		import select
		assert self.parallel_solvers
		# map each solver's output file descriptor back to its key
		fds = dict ([(output.fileno (), k) for (k, (_, _, output, _, _))
			in self.parallel_solvers.iteritems ()])
		try:
			(rlist, _, _) = select.select (fds.keys (), [], [])
		except KeyboardInterrupt, e:
			self.close_parallel_solvers (reason = 'interrupted')
			raise e
		# handle just one of the ready solvers per step
		k = fds[rlist.pop ()]
		(hyps, proc, output, solver, model) = self.parallel_solvers[k]
		del self.parallel_solvers[k]
		response = output.readline ().strip ()
		trace ('  <-- parallel solver %s closed: %s' % (k, response))
		trace ('      after %s' % run_time (None, proc))
		if response not in ['sat', 'unsat']:
			trace ('SMT conversation problem in parallel solver')
		trace ('Got %r from %s in parallel.' % (response, solver.name))
		m = {}
		check = None
		if response == 'sat':
			last_satisfiable_hyps[0] = hyps
		# an ordinary solver is finished unless it said 'sat' and a
		# model was requested. 'ModelRepair' solvers always go on to
		# the model checking below.
		if k[0] != 'ModelRepair':
			if model == None or response != 'sat':
				output.close ()
				return (k, hyps, response)
		# model issues
		m = {}
		if model != None:
			res = self.fetch_model_response (m, stream = output)
		output.close ()
		if model != None and not res:
			# just drop this solver at this point
			trace ('failed to extract model.')
			return None
		if k[0] == 'ModelRepair':
			# recover the original key, and the state and hyps saved
			# when the repair solver was started
			(_, k, i) = k
			(state, hyps) = self.parallel_model_states[k]
		else:
			i = 0
			state = None
		res = self.check_model_iteration (hyps, state, (response, m))
		(kind, details) = res
		if kind == 'Abort':
			return None
		elif kind == 'Result':
			model.update (details)
			return (k, hyps, 'sat')
		elif kind == 'Continue':
			# more checking needed: save the state and start another
			# solver under a 'ModelRepair' key
			(solv, test_hyps, state) = details
			self.parallel_model_states[k] = (state, hyps)
			k = ('ModelRepair', k, i + 1)
			self.add_parallel_solver (k, test_hyps,
				use_this_solver = solv, model = model)
			return None
1246
1247	def wait_parallel_solver (self):
1248		while True:
1249			assert self.parallel_solvers
1250			try:
1251				res = self.wait_parallel_solver_step ()
1252			except ConversationProblem, e:
1253				continue
1254			if res != None:
1255				return res
1256
1257	def close_parallel_solvers (self, ks = None, reason = '?'):
1258		if ks == None:
1259			ks = self.parallel_solvers.keys ()
1260		else:
1261			ks = [k for k in ks if k in self.parallel_solvers]
1262		solvs = [(proc, output) for (_, proc, output, _, _)
1263			in [self.parallel_solvers[k] for k in ks]]
1264		if ks:
1265			trace (' X<-- %d parallel solvers killed: %s'
1266				% (len (ks), reason))
1267		for k in ks:
1268			del self.parallel_solvers[k]
1269		procs = [proc for (proc, _) in solvs]
1270		outputs = [output for (_, output) in solvs]
1271		for proc in procs:
1272			os.killpg (proc.pid, signal.SIGTERM)
1273		for output in outputs:
1274			output.close ()
1275		for proc in procs:
1276			os.killpg (proc.pid, signal.SIGKILL)
1277			proc.wait ()
1278
	def parallel_check_hyps (self, hyps, env, model = None):
		"""test a series of keyed hypotheses [(k1, h1), (k2, h2) ..etc]
		using the parallel solvers of self.strategy.

		hypotheses that the fast solver can confirm are dropped
		first. returns a pair (res, k):
		('unsat', None) if the fast solver confirmed everything,
		('unsat', k) if the rest were confirmed in parallel, either
			all at once (k is None) or one by one (k is the
			last key),
		('sat', k) if the hypothesis with key k was refuted,
		('timeout', k) if every solver finished without settling
			the question, k being the key last reported."""
		hyps = [(k, hyp) for (k, hyp) in hyps
			if not self.test_hyp (hyp, env, force_solv = 'Fast',
				catch = True, hyp_name = "('hyp', %s)" % k)]
		assert not self.parallel_solvers
		if not hyps:
			return ('unsat', None)
		all_hyps = foldr1 (syntax.mk_and, [h for (k, h) in hyps])
		# start a solver on the negated hypothesis for every strategy
		# entry with the given strategy key
		def spawn ((k, hyp), stratkey):
			goal = smt_expr (syntax.mk_not (hyp), env, self)
			[self.add_parallel_solver ((solver.name, strat, k),
					[goal], use_this_solver = solver,
					model = model)
				for (solver, strat) in self.strategy
				if strat == stratkey]
		# the 'all' solvers attack the conjunction, while the 'hyp'
		# solvers work through the hypotheses one at a time
		if len (hyps) > 1:
			spawn ((None, all_hyps), 'all')
		spawn (hyps[0], 'hyp')
		solved = 0
		while True:
			((nm, strat, k), _, res) = self.wait_parallel_solver ()
			if strat == 'all' and res == 'unsat':
				trace ('  -- hyps all confirmed by %s' % nm)
				break
			elif strat == 'hyp' and res == 'sat':
				trace ('  -- hyp refuted by %s' % nm)
				break
			elif strat == 'hyp' and res == 'unsat':
				# this hypothesis is done. stop any other solvers
				# working on it and move on to the next one
				ks = [(solver.name, strat, k)
					for (solver, strat) in self.strategy]
				self.close_parallel_solvers (ks,
					reason = 'hyp shown unsat')
				solved += 1
				if solved < len (hyps):
					spawn (hyps[solved], 'hyp')
				else:
					trace ('  - hyps confirmed individually')
					break
			if not self.parallel_solvers:
				res = 'timeout'
				trace ('  - all solvers timed out or failed.')
				break
		self.close_parallel_solvers (reason = ('checked %r' % res))
		return (res, k)
1326
1327	def parallel_test_hyps (self, hyps, env, model = None):
1328		(res, k) = self.parallel_check_hyps (hyps, env, model)
1329		return (res == 'unsat', k, res)
1330
1331	def slow_solver_multisat (self, hyps, model = None, timeout = 300):
1332		trace ('multisat check.')
1333		start = time.time ()
1334
1335		cmds = []
1336		for hyp in hyps:
1337			cmds.extend (['(assert %s)' % hyp, '(check-sat)'])
1338			if model != None:
1339				cmds.append (self.fetch_model_request ())
1340		(proc, output) = self.exec_slow_solver (cmds, timeout = timeout)
1341
1342		assert hyps
1343		for (i, hyp) in enumerate (hyps):
1344			trace ('multisat checking %s' % hyp)
1345			response = output.readline ().strip ()
1346			if response == 'sat':
1347				if model != None:
1348					model.clear ()
1349					most_sat = hyps[: i + 1]
1350					assert self.fetch_model_response (model,
1351						stream = output)
1352			else:
1353				self.solver = None
1354				if i == 0 and response == 'unsat':
1355					self.send ('(assert (not %s))' % hyp)
1356				if i > 0:
1357					if response != 'unsat':
1358						trace ('conversation problem:')
1359						trace ('multisat got %r' % response)
1360					response = 'sat'
1361				break
1362
1363		if model:
1364			assert self.check_model (most_sat, model)
1365
1366		end = time.time ()
1367		trace ('multisat final result: %r after %s' % (response,
1368			run_time (end - start, proc)))
1369		output.close ()
1370
1371		return response
1372
1373	def fetch_model_request (self):
1374		vs = self.model_vars
1375		exprs = self.model_exprs
1376
1377		trace ('will fetch model%s for %d vars and %d compound exprs.'
1378			% (self.name_ext, len (vs), len (exprs)))
1379
1380		vs2 = tuple (vs) + tuple ([nm for (nm, typ) in exprs.values ()])
1381		return '(get-value (%s))' % ' '.join (vs2)
1382
1383	def fetch_model_response (self, model, stream = None, recursion = False):
1384		if stream == None:
1385			stream = self.online_solver.stdout
1386		values = get_s_expression (stream,
1387				'fetch_model_response')
1388		if values == None:
1389			trace ('Failed to fetch model!')
1390			return None
1391
1392		expected_set = set (list (self.model_vars)
1393			+ [nm for (nm, typ) in self.model_exprs.values ()])
1394		malformed = [v for v in values if len (v) != 2
1395			or v[0] not in expected_set]
1396		if malformed:
1397			trace ('bad model response components:')
1398			for v in malformed:
1399				trace (repr (v))
1400			return None
1401
1402		filt_values = [(nm, v) for (nm, v) in values
1403			if type (v) == str or '_' in v
1404			if set (v) != set (['?'])]
1405		dropped = len (values) - len (filt_values)
1406		if dropped:
1407			trace ('Dropped %d of %d values' % (dropped, len (values)))
1408			if recursion:
1409				trace (' .. aborting recursive attempt.')
1410				return None
1411
1412		abbrvs = [(sexp, name) for (sexp, (name, typ))
1413			in self.model_exprs.iteritems ()]
1414
1415		r = make_model (filt_values, model, abbrvs)
1416		if dropped:
1417			model[('IsIncomplete', None)] = True
1418		return r
1419
1420	def get_arbitrary_vars (self, typ):
1421		self.arbitrary_vars.setdefault (typ, [])
1422		vs = self.arbitrary_vars[typ]
1423		def add ():
1424			v = self.add_var ('arbitary-var', typ, kind = 'Aux')
1425			vs.append (v)
1426			return v
1427		import itertools
1428		return itertools.chain (vs, itertools.starmap (add,
1429			itertools.repeat ([])))
1430
1431	def force_model_accuracy_hyps (self):
1432		if len (self.model_exprs) == self.last_model_acc_hyps[0]:
1433			return self.last_model_acc_hyps[1]
1434		words = set ()
1435		for (var_nm, typ) in self.model_exprs.itervalues ():
1436			if typ.kind == 'Word':
1437				s = '((_ extract %d %d) %s)' % (typ.num - 1,
1438					typ.num - 2, var_nm)
1439				words.add (s)
1440			elif typ == syntax.boolT:
1441				s = '(ite %s #b10 #b01)' % var_nm
1442				words.add (s)
1443			else:
1444				assert not 'model acc type known', typ
1445		hyps = []
1446		w2T = syntax.Type ('Word', 2)
1447		arb_vars = self.get_arbitrary_vars (w2T)
1448		while words:
1449			while len (words) < 4:
1450				words.add (arb_vars.next ())
1451			[a, b, c, d] = [words.pop () for x in range (4)]
1452			x = arb_vars.next ()
1453			y = arb_vars.next ()
1454			hyps.append (('(word2-xor-scramble %s)'
1455				% ' '.join ([a, x, b, c, y, d]), None))
1456		self.last_model_acc_hyps = (len (self.model_exprs), hyps)
1457		return hyps
1458
	def check_model_iteration (self, hyps, state, (res, model)):
		"""process an iteration of checking a model. the solvers
		sometimes give partially bogus models, which we need to
		check for.
		the state at any time is (confirmed, test, cand_model, solvs)
		We build additional hypotheses (e.g. x = 0) from models.
		The 'confirmed' additional hyps have been shown sat together
		with the original hyps, and 'test' are under test this
		iteration. If the test is true, 'cand_model' will be
		confirmed to be a valid (possibly incomplete) model.
		We may prune a model down to an incomplete one to try to
		find a valid part. The 'solvs' are solvers which have yet
		to have a model tested (as candidate) from the current
		'confirmed' hyps.

		the arguments are the original hyps, the state from the
		previous iteration (None on the first) and the (result,
		model) pair of the most recent solver run.
		returns one of
		('Result', model) once a complete model is confirmed,
		('Abort', None) if no valid model can be found,
		('Continue', (solver, hyps to test next, new state))."""
		if state == None:
			# first iteration: nothing under test, every solver of the
			# model strategy still available
			test = []
			confirmed = hyps
			assert res == 'sat'
			cand_model = None
			solvs = self.model_strategy
		else:
			(confirmed, test, cand_model, solvs) = state

		# a confirmed candidate is the answer, unless it is incomplete
		if cand_model and res == 'sat':
			if ('IsIncomplete', None) not in cand_model:
				return ('Result', cand_model)

		if res == 'sat':
			if set (test) - set (confirmed):
				# confirm experiment
				confirmed = sorted (set (confirmed + test))
				# progress was made, reenable all solvers
				solvs = solvs + [s for s in self.model_strategy
					if s not in solvs]
		elif res == 'unsat' and not test:
			# nothing was added to hyps already believed sat
			printout ("WARNING: inconsistent sat/unsat.")
			inconsistent_hyps.append ((self, hyps, confirmed))
			return ('Abort', None)
		else:
			# since not sat, ignore model contents
			model = None

		# the candidate model wasn't confirmed, but might we
		# learn more by reducing it?
		if cand_model and res != 'sat':
			reduced = self.reduce_model (cand_model, hyps)
			r_hyps = get_model_hyps (reduced, self)
			solv = (solvs + self.model_strategy)[0]
			if set (r_hyps) - set (confirmed):
				return ('Continue', (solv, confirmed + r_hyps,
					(confirmed, r_hyps, None, solvs)))

		# ignore the candidate model now, and try to continue with
		# the most recently returned model. that expires the solver
		# that produced this model from solvs
		solvs = solvs[1:]
		if not model and not solvs:
			# out of options
			return ('Abort', None)
		# fall back on the start of the model strategy if solvs has
		# run out
		solv = (solvs + self.model_strategy)[0]

		if model:
			test_hyps = get_model_hyps (model, self)
		else:
			model = None
			test_hyps = []
		return ('Continue', (solv, confirmed + test_hyps,
			(confirmed, test_hyps, model, solvs)))
1527
1528	def check_model (self, hyps, model):
1529		orig_model = model
1530		state = None
1531		arg = ('sat', dict (model))
1532		while True:
1533			res = self.check_model_iteration (hyps, state, arg)
1534			(kind, details) = res
1535			if kind == 'Abort':
1536				return False
1537			elif kind == 'Result':
1538				orig_model.clear ()
1539				orig_model.update (details)
1540				return True
1541			assert kind == 'Continue'
1542			(solv, test_hyps, state) = details
1543			m = {}
1544			res = self.hyps_sat_raw (test_hyps, model = m,
1545				force_solv = solv, recursion = True)
1546			arg = (res, m)
1547
1548	def reduce_model (self, model, hyps):
1549		all_hyps = hyps + [h for (h, _) in self.assertions]
1550		all_hyps = map (parse_s_expression, all_hyps)
1551		m = reduce_model (model, self, all_hyps)
1552		trace ('reduce model size: %d --> %d' % (len (model), len (m)))
1553		return m
1554
1555	def fetch_model (self, model, recursion = False):
1556		try:
1557			self.write (self.fetch_model_request ())
1558		except IOError, e:
1559			raise ConversationProblem ('fetch-model', 'IOError')
1560		return self.fetch_model_response (model, recursion = recursion)
1561
1562	def get_unsat_core (self):
1563		res = self.prompt_s_expression_inner ('(get-unsat-core)')
1564		if res == None:
1565			return None
1566		if [s for s in res if type (s) != str]:
1567			raise ConversationProblem ('(get-unsat-core)', res)
1568		return res
1569
1570	def check_hyp (self, hyp, env, model = None, force_solv = False,
1571			hyp_name = None):
1572		hyp = smt_expr (hyp, env, self)
1573		return self.check_hyp_raw (hyp, model = model,
1574			force_solv = force_solv, hyp_name = hyp_name)
1575
1576	def test_hyp (self, hyp, env, model = None, force_solv = False,
1577			catch = False, hyp_name = None):
1578		if catch:
1579			try:
1580				res = self.check_hyp (hyp, env, model = model,
1581					force_solv = force_solv,
1582					hyp_name = hyp_name)
1583			except SolverFailure, e:
1584				return False
1585		else:
1586			res = self.check_hyp (hyp, env, model = model,
1587				force_solv = force_solv, hyp_name = hyp_name)
1588		return res == 'unsat'
1589
1590	def assert_fact_smt (self, fact, unsat_tag = None):
1591		self.assertions.append ((fact, unsat_tag))
1592		if unsat_tag and self.unsat_cores:
1593			name = 'assert%d' % len (self.assertions)
1594			self.send ('(assert (! %s :named %s))' % (fact, name),
1595				is_model = False)
1596		else:
1597			self.send ('(assert %s)' % fact)
1598
1599	def assert_fact (self, fact, env, unsat_tag = None):
1600		fact = smt_expr (fact, env, self)
1601		self.assert_fact_smt (fact, unsat_tag = unsat_tag)
1602
1603	def get_smt_derived_oper (self, name, n):
1604		if (name, n) in self.smt_derived_ops:
1605			return self.smt_derived_ops[(name, n)]
1606		if n != 1:
1607			m = n / 2
1608			top = '((_ extract %d %d) x)' % (n - 1, m)
1609			bot = '((_ extract %d 0) x)' % (m - 1)
1610			top_app = '(%s %s)' % (self.get_smt_derived_oper (
1611				name, n - m), top)
1612			top_appx = '((_ zero_extend %d) %s)' % (m, top_app)
1613			bot_app = '(%s %s)' % (self.get_smt_derived_oper (
1614				name, m), bot)
1615			bot_appx = '((_ zero_extend %d) %s)' % (n - m, bot_app)
1616		if name == 'CountLeadingZeroes':
1617			fname = 'bvclz_%d' % n
1618		elif name == 'WordReverse':
1619			fname = 'bvrev_%d' % n
1620		else:
1621			assert not 'name understood', (name, n)
1622		fname = self.smt_name (fname, kind = 'Fun')
1623
1624		if name == 'CountLeadingZeroes' and n == 1:
1625			self.send ('(define-fun %s ((x (_ BitVec 1)))' % fname
1626				+ ' (_ BitVec 1) (ite (= x #b0) #b1 #b0))')
1627		elif name == 'CountLeadingZeroes':
1628			self.send (('(define-fun %s ((x (_ BitVec %d)))'
1629				+ ' (_ BitVec %d) (ite (= %s %s)'
1630				+ ' (bvadd %s %s) %s))')
1631				% (fname, n, n, top, smt_num (0, n - m),
1632					bot_appx, smt_num (m, n), top_appx))
1633		elif name == 'WordReverse' and n == 1:
1634			self.send ('(define-fun %s ((x (_ BitVec 1)))' % fname
1635				+ ' (_ BitVec 1) x)')
1636		elif name == 'WordReverse':
1637			self.send (('(define-fun %s ((x (_ BitVec %d)))'
1638				+ ' (_ BitVec %d) (concat %s %s))')
1639				% (fname, n, n, bot_app, top_app))
1640		else:
1641			assert not True
1642		self.smt_derived_ops[(name, n)] = fname
1643		return fname
1644
1645		# this is how you would test it
1646		num = random.randrange (0, 2 ** n)
1647		clz = len (bin (num + (2 ** n))[3:].split('1')[0])
1648		assert self.check_hyp_raw ('(= (bvclz_%d %s) %s)' %
1649			(n, smt_num (num, n), smt_num (clz, n))) == 'unsat'
1650		num = num >> random.randrange (0, n)
1651		clz = len (bin (num + (2 ** n))[3:].split('1')[0])
1652		assert self.check_hyp_raw ('(= (bvclz_%d %s) %s)' %
1653			(n, smt_num (num, n), smt_num (clz, n))) == 'unsat'
1654
1655	def cache_large_expr (self, s, name, typ):
1656		if s in self.cached_exprs:
1657			return self.cached_exprs[s]
1658		if len (s) < 80:
1659			return s
1660		name2 = self.add_def (name, mk_smt_expr (s, typ), {})
1661		self.cached_exprs[s] = name2
1662		self.cached_exprs[(name2, 'IsCachedExpr')] = True
1663		return name2
1664
1665	def note_ptr (self, p_s):
1666		if p_s in self.ptrs:
1667			p = self.ptrs[p_s]
1668		else:
1669			p = self.add_def ('ptr', mk_smt_expr (p_s, word32T), {})
1670			self.ptrs[p_s] = p
1671		return p
1672
	def add_pvalids (self, htd_s, typ, p_s, kind, recursion = False):
		"""return the SMT name of a boolean variable standing for a
		pvalid claim of the given kind about pointer p_s and type
		typ in the heap type description htd_s.

		an 'ite' htd is handled by recursing into both branches.
		the first time an htd is seen (and not in recursion),
		global-valid claims are asserted for all the rodata
		pointers. a new variable is related to every claim already
		recorded for the same htd with the assertions built by
		pvalid_assertion1 and pvalid_assertion2."""
		htd_sexp = parse_s_expression (htd_s)
		if htd_sexp[0] == 'ite':
			[cond, l, r] = map (flat_s_expression, htd_sexp[1:])
			return '(ite %s %s %s)' % (cond,
				self.add_pvalids (l, typ, p_s, kind),
				self.add_pvalids (r, typ, p_s, kind))

		pvalids = self.pvalids
		if htd_s not in pvalids and not recursion:
			[_, _, rodata_ptrs] = rodata
			if not rodata_ptrs:
				rodata_ptrs = []
			for (r_addr, r_typ) in rodata_ptrs:
				r_addr_s = smt_expr (r_addr, {}, None)
				# recursion = True stops this block being entered again
				var = self.add_pvalids (htd_s, ('Type', r_typ),
					r_addr_s, 'PGlobalValid',
					recursion = True)
				self.assert_fact_smt (var)

		p = self.note_ptr (p_s)

		trace ('adding pvalid with type %s' % (typ, ))

		if htd_s in pvalids and (typ, p, kind) in pvalids[htd_s]:
			return pvalids[htd_s][(typ, p, kind)]
		else:
			var = self.add_var ('pvalid', boolT)
			pvalids.setdefault (htd_s, {})
			# snapshot of the existing claims, taken before the new
			# one is added
			others = pvalids[htd_s].items()
			pvalids[htd_s][(typ, p, kind)] = var

			# convert a ((typ, p, kind), var) entry to the tuple
			# form passed to the pvalid_assertion functions
			def smtify (((typ, p, kind), var)):
				return (typ, kind, mk_smt_expr (p, word32T),
					mk_smt_expr (var, boolT))
			pdata = smtify (((typ, p, kind), var))
			(_, _, p, pv) = pdata
			impl_al = mk_implies (pv, mk_align_valid_ineq (typ, p))
			self.assert_fact (impl_al, {})
			for val in others:
				kinds = [val[0][2], pdata[1]]
				# a weak claim is only related to global claims
				if ('PWeakValid' in kinds and
						'PGlobalValid' not in kinds):
					continue
				ass = pvalid_assertion1 (pdata, smtify (val))
				ass_s = smt_expr (ass, None, None)
				self.assert_fact_smt (ass_s, unsat_tag =
					('PValid', 1, var, val[1]))
				ass = pvalid_assertion2 (pdata, smtify (val))
				ass_s = smt_expr (ass, None, None)
				self.assert_fact_smt (ass_s,
					('PValid', 2, var, val[1]))

			trace ('Now %d related pvalids' % len(pvalids[htd_s]))
			return var
1728
1729	def get_imm_basis_mems (self, m, accum):
1730		if m[0] == 'ite':
1731			(_, c, l, r) = m
1732			self.get_imm_basis_mems (l, accum)
1733			self.get_imm_basis_mems (r, accum)
1734		elif m[0] in ['store-word32', 'store-word8']:
1735			(_, m, p, v) = m
1736			self.get_imm_basis_mems (m, accum)
1737		elif type (m) == tuple:
1738			assert not 'mem construction understood', m
1739		elif (m, 'IsCachedExpr') in self.cached_exprs:
1740			self.get_imm_basis_mems (self.defs[m], accum)
1741		else:
1742			assert type (m) == str
1743			accum.add (m)
1744
1745	def get_basis_mems (self, m):
1746		# the obvious implementation requires exponential exploration
1747		# and may overrun the recursion limit.
1748		mems = set ()
1749		processed_defs = set ()
1750
1751		self.get_imm_basis_mems (m, mems)
1752		while True:
1753			proc = [m for m in mems if m in self.defs
1754				if m not in processed_defs]
1755			if not proc:
1756				return mems
1757			for m in proc:
1758				self.get_imm_basis_mems (self.defs[m], mems)
1759				processed_defs.add (m)
1760
1761	def add_split_mem_var (self, addr, nm, typ, mem_name = None):
1762		assert typ == builtinTs['Mem']
1763		bot_mem = self.add_var (nm + '_bot', typ, mem_name = mem_name)
1764		top_mem = self.add_var (nm + '_top', typ, mem_name = mem_name)
1765		self.stack_eqs[('StackEqImpliesCheck', top_mem)] = None
1766		return ('SplitMem', addr, top_mem, bot_mem)
1767
1768	def add_implies_stack_eq (self, sp, s1, s2, env):
1769		k = ('ImpliesStackEq', sp, s1, s2)
1770		if k in self.stack_eqs:
1771			return self.stack_eqs[k]
1772
1773		addr = self.add_var ('stack-eq-witness', word32T)
1774		self.assert_fact_smt ('(= (bvand %s #x00000003) #x00000000)'
1775			% addr)
1776		sp_smt = smt_expr (sp, env, self)
1777		self.assert_fact_smt ('(bvule %s %s)' % (sp_smt, addr))
1778		ptr = mk_smt_expr (addr, word32T)
1779		eq = syntax.mk_eq (syntax.mk_memacc (s1, ptr, word32T),
1780			syntax.mk_memacc (s2, ptr, word32T))
1781		stack_eq = self.add_def ('stack-eq', eq, env)
1782		self.stack_eqs[k] = stack_eq
1783		return stack_eq
1784
1785	def get_stack_eq_implies (self, split, st_top, other):
1786		if other[0] == 'SplitMem':
1787			[_, split2, top2, bot2] = other
1788			rhs = top2
1789			cond = '(bvule %s %s)' % (split2, split)
1790		else:
1791			rhs = other
1792			cond = 'true'
1793		self.note_model_expr ('(= %s %s)' % (st_top, rhs), boolT)
1794		mems = set ()
1795		self.get_imm_basis_mems (parse_s_expression (st_top), mems)
1796		assert len (mems) == 1, mems
1797		[st_top_base] = list (mems)
1798		k = ('StackEqImpliesCheck', st_top_base)
1799		assert k in self.stack_eqs, k
1800		assert self.stack_eqs[k] in [None, rhs], [k,
1801			self.stack_eqs[k], rhs]
1802		self.stack_eqs[k] = rhs
1803		return '(=> %s (= %s %s))' % (cond, st_top, rhs)
1804
1805	def get_token (self, string):
1806		if ('Token', string) not in self.tokens:
1807			n = len (self.tokens) + 1
1808			v = self.add_def ("token_%s" % string,
1809				syntax.mk_num (n, token_smt_typ), {})
1810			self.tokens[('Token', string)] = v
1811			self.tokens[('Val', self.defs[v])] = string
1812		return self.tokens[('Token', string)]
1813
1814	def note_mem_dom (self, p, d, md):
1815		self.doms.add ((p, d, md))
1816
1817	def note_model_expr (self, sexpr, typ):
1818		psexpr = parse_s_expression (sexpr)
1819		if psexpr not in self.model_exprs:
1820			s = ''.join ([c for c in sexpr if c not in " ()"])
1821			s = s[:20]
1822			smt_expr = mk_smt_expr (sexpr, typ)
1823			v = self.add_def ('query_' + s, smt_expr, {})
1824			self.model_exprs[psexpr] = (v, typ)
1825
	def add_pvalid_dom_assertions (self):
		"""Assert the facts linking pvalid variables and program
		sections to the noted memory-domain memberships.

		Does nothing when no domain entries have been noted, when
		cheat_mem_doms is set, or when the (pvalid ranges, doms) pair
		equals the one recorded in self.pvalid_doms by an earlier
		call."""
		if not self.doms:
			return
		if cheat_mem_doms:
			return
		# the domain expression is taken from an arbitrary noted entry.
		# NOTE(review): presumably all entries share one domain -
		# confirm against the callers of note_mem_dom.
		dom = iter (self.doms).next ()[1]

		# (guard, (pointer, size)) pairs: one per recorded pvalid,
		# guarded by its variable and sized by its type ...
		pvs = [(var, (p, typ.size ()))
			for env in self.pvalids.itervalues ()
			for ((typ, p, kind), var) in env.iteritems ()]
		# ... plus one unconditional ('true') range per section, whose
		# end address is inclusive.
		pvs += [('true', (smt_num (start, 32), (end - start) + 1))
				for (start, end) in sections.itervalues ()]

		# skip the work if nothing changed since the last emission.
		pvalid_doms = (pvs, set (self.doms))
		if self.pvalid_doms == pvalid_doms:
			return

		trace ('PValid/Dom complexity: %d, %d' % (len (pvalid_doms[0]),
			len (pvalid_doms[1])))
		for (var, (p, sz)) in pvs:
			if sz > len (self.doms) * 4:
				# large range: one implication per noted entry,
				# "q within [p, p + sz - 1] implies md".
				# NOTE(review): the guard var does not appear in
				# these implications - confirm this is intended.
				for (q, _, md) in self.doms:
					left = '(bvule %s %s)' % (p, q)
					right = ('(bvule %s (bvadd %s %s))'
						% (q, p, smt_num (sz - 1, 32)))
					lhs = '(and %s %s)' % (left, right)
					self.assert_fact_smt ('(=> %s %s)'
						% (lhs, md))
			else:
				# small range: the guard implies mem-dom membership
				# of every address offset in the range.
				vs = ['(mem-dom (bvadd %s %s) %s)'
						% (p, smt_num (i, 32), dom)
					for i in range (sz)]
				self.assert_fact_smt ('(=> %s (and %s))'
					% (var, ' '.join (vs)))

		# remember what was emitted for the unchanged-check above.
		self.pvalid_doms = pvalid_doms
1862
1863	def narrow_unsat_core (self, solver, asserts):
1864		process = subprocess.Popen (solver[1],
1865			stdin = subprocess.PIPE, stdout = subprocess.PIPE,
1866			preexec_fn = preexec (solver[2]))
1867		self.write_solv_script (process.stdin, [], solver = solver,
1868			only_if_is_model = True)
1869		asserts = list (asserts)
1870		for (i, (ass, tag)) in enumerate (asserts):
1871			process.stdin.write ('(assert (! %s :named uc%d))\n'
1872				% (ass, i))
1873		process.stdin.write ('(check-sat)\n(get-unsat-core)\n')
1874		process.stdin.close ()
1875		try:
1876			res = get_s_expression (process.stdout, '(check-sat)')
1877			core = get_s_expression (process.stdout,
1878				'(get-unsat-core)')
1879		except ConversationProblem, e:
1880			return asserts
1881		trace ('got response %r from %s' % (res, solver[0]))
1882		if res != 'unsat':
1883			return asserts
1884		for s in core:
1885			assert s.startswith ('uc')
1886		return set ([asserts[int (s[2:])] for s in core])
1887
1888	def unsat_core_loop (self, asserts):
1889		asserts = set (asserts)
1890
1891		orig_num_asserts = len (asserts) + 1
1892		while len (asserts) < orig_num_asserts:
1893			orig_num_asserts = len (asserts)
1894			trace ('Entering unsat_core_loop, %d asserts.'
1895				% orig_num_asserts)
1896			for solver in unsat_solver_loop:
1897				asserts = self.narrow_unsat_core (solver,
1898					asserts)
1899				trace (' .. now %d asserts.' % len (asserts))
1900		return set ([tag for (_, tag) in asserts])
1901
1902	def unsat_core_with_loop (self, hyps, env):
1903		unsat_core = []
1904		hyps = [(smt_expr (hyp, env, self), tag) for (hyp, tag) in hyps]
1905		try:
1906			res = self.hyps_sat_raw (hyps, unsat_core = unsat_core)
1907		except ConversationProblem, e:
1908			res = 'unsat'
1909			unsat_core = []
1910		if res != 'unsat':
1911			return res
1912		if unsat_core == []:
1913			core = list (hyps) + list (self.assertions)
1914		else:
1915			unsat_core = set (unsat_core)
1916			core = [(ass, tag) for (ass, tag) in hyps
1917				if tag in unsat_core] + [(ass, tag)
1918				for (ass, tag) in self.assertions
1919				if tag in unsat_core]
1920		return self.unsat_core_loop (core)
1921
def merge_envs (envs, solv):
	"""Merge several (path condition, environment) pairs into one
	environment.

	For each variable, the values it takes in the various
	environments are combined with smt_ifthenelse, each value guarded
	by the SMT form of the path condition(s) under which it occurs;
	one of the values serves as the final else-case."""
	guards_by_var = {}
	for (pc, env) in envs:
		guard = smt_expr (pc, env, solv)
		for (var, val) in env.iteritems ():
			by_val = guards_by_var.setdefault (var, {})
			by_val.setdefault (val, []).append (guard)

	merged = {}
	for var in guards_by_var:
		options = guards_by_var[var].items ()
		# the last option is the default; the rest are guarded.
		(result, _) = options[-1]
		for (val, guards) in options[:-1]:
			if len (guards) > 1:
				guard = '(or %s)' % (' '.join (guards))
			else:
				guard = guards[0]
			result = smt_ifthenelse (guard, val, result, solv)
		merged[var] = result
	return merged
1944
def fold_assoc_balanced (f, xs):
	"""Fold the sequence xs with the binary function f, splitting it
	into halves so that the resulting expression tree is balanced.

	Sequences of fewer than four elements are handed to foldr1. The
	split preserves element order, so f is only required to be
	associative."""
	if len (xs) < 4:
		return foldr1 (f, xs)
	# floor division keeps the split point an integer even where '/'
	# means true division.
	mid = len (xs) // 2
	lhs = fold_assoc_balanced (f, xs[:mid])
	rhs = fold_assoc_balanced (f, xs[mid:])
	return f (lhs, rhs)
1953
def merge_envs_pcs (pc_envs, solv):
	"""Merge (path condition, environment) pairs, ignoring those whose
	path condition is false_term.

	Returns (path_cond, env, several): path_cond is the balanced
	disjunction of the distinct remaining path conditions (false_term
	if none remain), env is their merged environment and several says
	whether more than one pair remained."""
	live = [(pc, env) for (pc, env) in pc_envs if pc != false_term]
	if live:
		distinct_pcs = list (set ([pc for (pc, _) in live]))
		path_cond = fold_assoc_balanced (mk_or, distinct_pcs)
	else:
		path_cond = false_term
	merged = merge_envs (live, solv)
	return (path_cond, merged, len (live) > 1)
1964
def hash_test_hyp (solv, hyp, env, cache):
	"""Test the boolean hypothesis hyp via solv.test_hyp, memoising
	the result in cache under hyp's SMT string."""
	assert hyp.typ == boolT
	key = smt_expr (hyp, env, solv)
	if key not in cache:
		cache[key] = solv.test_hyp (mk_smt_expr (key, boolT), {})
	return cache[key]
1973
def hash_test_hyp_fast (solv, hyp, env, cache):
	"""Look up a previously cached result for the boolean hypothesis
	hyp without testing it; returns None if cache has no entry."""
	assert hyp.typ == boolT
	key = smt_expr (hyp, env, solv)
	return cache.get (key)
1978
# splits a string at parentheses; the capturing group makes re.split
# keep each parenthesis as an item of its own.
paren_re = re.compile (r"(\(|\))")

def parse_s_expressions (ss):
	"""Parse the strings in ss, read as one token stream, into a
	single s-expression.

	Atoms are returned as strings and parenthesised groups as tuples
	of parsed elements. Parsing uses an explicit stack rather than
	recursion, so nesting depth is not bounded by the interpreter's
	recursion limit.

	Raises IndexError if there are no tokens or a group is left
	unclosed, and AssertionError if tokens remain after the first
	complete expression."""
	bits = [bit for s in ss for split1 in paren_re.split (s)
		for bit in split1.split ()]
	if bits[0] != '(':
		# a lone atom.
		(n, v) = (1, bits[0])
	else:
		outer = []	# partially built enclosing groups
		xs = []		# elements of the group being read
		n = 1
		while True:
			bit = bits[n]
			n = n + 1
			if bit == '(':
				outer.append (xs)
				xs = []
			elif bit == ')':
				v = tuple (xs)
				if not outer:
					break
				xs = outer.pop ()
				xs.append (v)
			else:
				xs.append (bit)
	assert n == len (bits), ss
	return v
1996
def parse_s_expression (s):
	"""Parse the single string s as one s-expression; see
	parse_s_expressions for the result format."""
	pieces = [s]
	return parse_s_expressions (pieces)
1999
def smt_to_val (s, toplevel = None):
	"""Convert a parsed SMT literal into a syntax value.

	Handles the (_ bvN W) form, '#b...' and '#x...' bit-vector
	literals and the booleans 'true' and 'false'. Anything else,
	including any other tuple, fails an assertion. The toplevel
	parameter is not used."""
	if len (s) == 3 and s[0] == '_' and s[1][:2] == 'bv':
		# indexed form: s[1] is 'bv<value>', s[2] is the bit width.
		width = int (s[2])
		return syntax.mk_num (int (s[1][2:]), width)
	if type (s) == tuple:
		assert type (s) != tuple, s
	else:
		# each binary digit carries 1 bit, each hex digit 4.
		for (prefix, radix, digit_bits) in [('#b', 2, 1), ('#x', 16, 4)]:
			if s.startswith (prefix):
				return syntax.mk_num (int (s[2:], radix),
					(len (s) - 2) * digit_bits)
		if s == 'true':
			return true_term
		if s == 'false':
			return false_term
	assert not 'smt_to_val: smt expr understood', s
2017
# one-element holder: make_model overwrites slot 0 with the
# (sexp, abbrvs) pair of its most recent call.
last_primitive_model = [0]
2019
def eval_mem_name_sexp (m, defs, sexp):
	"""Follow the memory expression sexp back to a name that has no
	entry in defs, and return that name.

	'ite' nodes are resolved by evaluating their condition in the
	model m, 'store-word32' nodes are replaced by the memory they
	update, and names found in defs are replaced by their
	definitions."""
	import search
	while True:
		head = sexp[0]
		if head == 'ite':
			(_, cond, if_true, if_false) = sexp
			choice = search.eval_model (m, cond)
			if choice == syntax.true_term:
				sexp = if_true
			elif choice == syntax.false_term:
				sexp = if_false
			else:
				assert not 'eval_model result understood'
			continue
		if head == 'store-word32':
			# only the updated memory matters, not address or value.
			(_, base, _ptr, _val) = sexp
			sexp = base
			continue
		assert type (sexp) == str
		if sexp not in defs:
			return sexp
		sexp = defs[sexp]
2040
def eval_mem_names (m, defs, mem_names):
	"""Add to the model m, for each memory variable in mem_names, a
	pair (name chain, {}).

	A plain string naming yields the one-element chain (naming,). A
	(nm, sexp) naming extends the chain of the predecessor memory,
	found by evaluating sexp with eval_mem_name_sexp, by nm."""
	pending = {}
	for (m_var, naming) in mem_names.iteritems ():
		kind = type (naming)
		if kind == tuple:
			(nm, sexp) = naming
			pending[m_var] = (nm,
				eval_mem_name_sexp (m, defs, sexp))
		elif kind == str:
			m[m_var] = ((naming, ), {})
		else:
			assert not 'naming kind understood', naming
	todo = list (pending)
	while todo:
		m_var = todo.pop ()
		if m_var in m:
			continue
		(nm, pred) = pending[m_var]
		if pred not in m:
			# handle the predecessor first, then revisit m_var.
			todo.extend ([m_var, pred])
			continue
		(pred_chain, _) = m[pred]
		m[m_var] = (pred_chain + (nm,), {})
2063
def make_model (sexp, m, abbrvs = [], mem_defs = {}):
	"""Convert a solver model response into syntax values and merge it
	into the dict m.

	sexp is a sequence of (name, value) pairs. If any pair has both
	parts as tuples the model is rejected: False is returned and m is
	left untouched. Otherwise each value is converted with smt_to_val,
	each (abbrv_sexp, nm) in abbrvs whose nm was given a value gets the
	same value under abbrv_sexp, m is updated and True is returned.

	The raw input is kept in last_primitive_model and the converted
	model is appended to last_10_models, which is trimmed to its final
	ten entries. mem_defs is not used."""
	last_primitive_model[0] = (sexp, abbrvs)
	staged = {}
	try:
		for (nm, raw) in sexp:
			if type (nm) == tuple and type (raw) == tuple:
				return False
			staged[nm] = smt_to_val (raw)
		for (abbrv_sexp, nm) in abbrvs:
			if nm in staged:
				staged[abbrv_sexp] = staged[nm]
	except IndexError as e:
		print ('Error with make_model')
		print (sexp)
		raise e
	# m is only touched once conversion is known to have succeeded
	m.update (staged)
	last_10_models.append (staged)
	del last_10_models[:-10]
	return True
2084
def get_model_hyps (model, solv):
	"""Return one SMT equality string '(= key value)' per entry of
	model, skipping the ('IsIncomplete', None) key."""
	hyps = []
	for (x, v) in model.iteritems ():
		if x == ('IsIncomplete', None):
			continue
		lhs = flat_s_expression (x)
		rhs = smt_expr (v, {}, solv)
		hyps.append ('(= %s %s)' % (lhs, rhs))
	return hyps
2089
def add_key_model_vs (sexpr, m, solv, vs):
	"""Collect into the set vs the names in sexpr that matter to its
	value in the model m.

	Names are followed through solv.defs. For 'or' ('and'), if exactly
	one argument is true (false) in m only that argument is explored;
	if none is, all are; if several are, the whole expression is added
	to vs instead. '=>' is handled as the equivalent 'or'. For 'ite'
	the condition is explored, plus whichever branch m selects."""
	op = sexpr[0]
	if op == '=>':
		(_, lhs, rhs) = sexpr
		add_key_model_vs (('or', ('not', lhs), rhs), m, solv, vs)
		return
	if op in ('or', 'and'):
		# the value that settles the connective on its own.
		if op == 'or':
			decisive = syntax.true_term
		else:
			decisive = syntax.false_term
		args = sexpr[1:]
		deciders = [x for x in args
			if get_model_val (x, m) == decisive]
		if not deciders:
			for x in args:
				add_key_model_vs (x, m, solv, vs)
		elif len (deciders) == 1:
			add_key_model_vs (deciders[0], m, solv, vs)
		else:
			vs.add (sexpr)
		return
	if op == 'ite':
		(_, cond, if_true, if_false) = sexpr
		choice = get_model_val (cond, m)
		add_key_model_vs (cond, m, solv, vs)
		if choice == syntax.true_term:
			add_key_model_vs (if_true, m, solv, vs)
		if choice == syntax.false_term:
			add_key_model_vs (if_false, m, solv, vs)
		return
	if type (sexpr) != str:
		# any other operator: explore every argument.
		for x in sexpr[1:]:
			add_key_model_vs (x, m, solv, vs)
		return
	if sexpr not in vs:
		vs.add (sexpr)
		if sexpr in solv.defs:
			add_key_model_vs (solv.defs[sexpr], m, solv, vs)
2130
def get_model_val (sexpr, m, toplevel = None):
	"""Evaluate sexpr in the model m with search.eval_model, returning
	None if the evaluation fails an assertion. toplevel is not
	used."""
	import search
	try:
		val = search.eval_model (m, sexpr)
	except AssertionError:
		# this is awful, but happens sometimes because we're
		# evaluating in incomplete models
		val = None
	return val
2139
# one-element holder: reduce_model stores the (m, solv, hyps) triple
# of its most recent call in slot 0.
last_model_to_reduce = [0]

def reduce_model (m, solv, hyps):
	"""Return a copy of the model m restricted to the keys that
	add_key_model_vs collects for the expressions in hyps."""
	last_model_to_reduce[0] = (m, solv, hyps)
	key_vs = set ()
	for hyp in hyps:
		add_key_model_vs (hyp, m, solv, key_vs)
	reduced = {}
	for k in m:
		if k in key_vs:
			reduced[k] = m[k]
	return reduced
2148
def flat_s_expression (s):
	"""Render a parsed s-expression as a string: tuples become
	space-separated parenthesised groups, anything else is returned
	unchanged."""
	if type (s) != tuple:
		return s
	parts = [flat_s_expression (x) for x in s]
	return '(%s)' % ' '.join (parts)
2154
# module-level mapping, initially empty.
# NOTE(review): no use is visible in this part of the file - confirm
# its purpose against whatever reads or fills it.
pvalid_type_map = {}
2156
def fun_cond_test (fun, unsats = None):
	"""Find nodes of fun that lead to 'Err' under a condition the
	solver reports can never hold.

	The nodes considered are the reachable ones whose continuations
	after the first are exactly ['Err'] and whose condition is not
	false_term. For each, (fun.name, node) is appended to unsats when
	solv.test_hyp returns True for the negated condition.

	unsats is the list to append to; a fresh list is used when it is
	None. The list is returned on every path, including when there is
	nothing to check."""
	if unsats is None:
		unsats = []
	ns = [n for n in fun.reachable_nodes (simplify = True)
		if fun.nodes[n].get_conts ()[1:] == ['Err']
		if fun.nodes[n].cond != syntax.false_term]
	if not ns:
		# nothing to check: skip creating a solver at all.
		return unsats
	solv = Solver ()
	for n in ns:
		node = fun.nodes[n]
		vs = syntax.get_node_rvals (node)
		# give each variable the node reads a fresh solver variable.
		env = dict ([((nm, typ), solv.add_var (nm, typ))
			for (nm, typ) in vs.iteritems ()])
		r = solv.test_hyp (syntax.mk_not (node.cond), env)
		if r == True:
			unsats.append ((fun.name, n))
	return unsats
2174
def cond_tests ():
	"""Run fun_cond_test over every function in
	target_objects.functions and assert that nothing was flagged."""
	unsats = []
	from target_objects import functions
	for (_name, fun) in functions.iteritems ():
		fun_cond_test (fun, unsats)
	assert not unsats, unsats
2180
2181#def compile_struct_pvalid ():
2182#def compile_pvalids ():
def quick_test (force_solv = False):
	"""quick test that the solver supports the needed features.

	Checks that check_hyp answers 'sat' for false_term and 'unsat' for
	true_term, then asserts v = 0 for a word32 variable and checks
	that the model returned with a 'sat' answer maps 'v' to that
	zero. force_solv is passed through to every check_hyp call."""
	solv = Solver ()
	solv.assert_fact (true_term, {})
	for (hyp, verdict) in [(false_term, 'sat'), (true_term, 'unsat')]:
		assert solv.check_hyp (hyp, {},
			force_solv = force_solv) == verdict
	var = syntax.mk_var ('v', word32T)
	zero = syntax.mk_word32 (0)
	env = {('v', word32T): solv.add_var ('v', word32T)}
	solv.assert_fact (syntax.mk_eq (var, zero), env)
	model = {}
	assert solv.check_hyp (false_term, {}, model = model,
		force_solv = force_solv) == 'sat'
	assert model.get ('v') == zero, model
2198
def test ():
	"""Locate the solver list, then run quick_test once with the
	default solver choice and once with force_solv = 'Slow', printing
	progress messages."""
	location = find_solverlist_file ()
	print ('Loaded solver list from %s' % location)
	for mode in [False, 'Slow']:
		quick_test (force_solv = mode)
	print ('Solver self-test successful')
2205
# command-line entry point: 'test' and 'testq' run the solver
# self-test, 'ctest' runs cond_tests; any other arguments do nothing.
if __name__ == "__main__":
	import sys
	import target_objects
	if sys.argv[1:] in (['testq'], ['test']):
		if sys.argv[1:] == ['testq']:
			# swap the tracer for a function that ignores its
			# arguments (presumably the 'q' stands for quiet).
			target_objects.tracer[0] = lambda x, y: ()
		test ()
	elif sys.argv[1:] == ['ctest']:
		cond_tests ()
2215
2216
2217