1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7import solver
8from solver import mk_smt_expr, to_smt_expr, smt_expr
9import check
10from check import restr_others, loops_to_split, ProofNode
11from rep_graph import (mk_graph_slice, vc_num, vc_offs, vc_upto,
12	vc_double_range, VisitCount, vc_offset_upto)
13import rep_graph
14from syntax import (mk_and, mk_cast, mk_implies, mk_not, mk_uminus, mk_var,
15	foldr1, boolT, word32T, word8T, builtinTs, true_term, false_term,
16	mk_word32, mk_word8, mk_times, Expr, Type, mk_or, mk_eq, mk_memacc,
17	mk_num, mk_minus, mk_plus, mk_less)
18import syntax
19import logic
20
21from target_objects import trace, printout
22import target_objects
23import itertools
24
# debug cell: most recently constructed (Eq)SearchKnowledge object
last_knowledge = [1]
26
class NoSplit(Exception):
	"""Raised by the split search when no usable split can be found."""
	pass
29
def get_loop_var_analysis_at (p, n):
	"""Fetch (and cache) the loop variable analysis for loop node n of
	problem p. Registered 'loop_var_analysis' hooks are consulted
	first; the first non-None answer wins, otherwise the problem's own
	analysis is used."""
	k = ('search_loop_var_analysis', n)
	if k in p.cached_analysis:
		return p.cached_analysis[k]
	for hook in target_objects.hooks ('loop_var_analysis'):
		res = hook (p, n)
		if res != None:
			p.cached_analysis[k] = res
			return res
	var_deps = p.compute_var_dependencies ()
	res = p.get_loop_var_analysis (var_deps, n)
	p.cached_analysis[k] = res
	return res
43
def get_loop_vars_at (p, n):
	"""Return the sorted variables classified 'LoopVariable' at node n,
	with the constant zero word appended as a baseline candidate."""
	analysis = get_loop_var_analysis_at (p, n)
	candidates = [var for (var, data) in analysis
		if data == 'LoopVariable']
	candidates.append (mk_word32 (0))
	return sorted (candidates)
49
# default loop unfolding depth used by the searchers
default_loop_N = 3

# debug cell: last proof constructed by build_proof
last_proof = [None]
53
def build_proof (p):
	"""Construct a proof for problem p by recursive search, starting
	from the initial point hypotheses."""
	init_hyps = check.init_point_hyps (p)
	proof = build_proof_rec (default_searcher, p, (), list (init_hyps))

	trace ('Built proof for %s' % p.name)
	printout (repr (proof))
	# stash for interactive debugging
	last_proof[0] = proof

	return proof
63
def split_sample_set (bound):
	"""Return a thinning sample of candidate split points below bound:
	every value up to 10, then steps of 2 to 20, 5 to 40, 10 to 100
	and 50 to 1000."""
	# itertools.chain avoids the Python-2-only concatenation of range
	# results and behaves identically on Python 2 and 3
	ns = itertools.chain (range (10), range (10, 20, 2),
		range (20, 40, 5), range (40, 100, 10),
		range (100, 1000, 50))
	return [n for n in ns if n < bound]
69
# debug cell: arguments of the most recent find_split_limit call
last_find_split_limit = [0]
71
def find_split_limit (p, n, restrs, hyps, kind, bound = 51, must_find = True,
		hints = None, use_rep = None):
	"""Search for the least visit count i (up to bound) at which node n
	can be shown unreachable, i.e. a limit on its visits. Returns the
	best (least) such i found, or None if there is none (asserting
	instead if must_find is set).

	Fixes: hints no longer defaults to a shared mutable list, the
	side-effecting map () call is an explicit loop (map is lazy on
	Python 3), and the midpoint uses floor division explicitly.
	"""
	if hints == None:
		hints = []
	tag = p.node_tags[n][0]
	trace ('Finding split limit: %d (%s)' % (n, tag))
	last_find_split_limit[0] = (p, n, restrs, hyps, kind)
	if use_rep == None:
		rep = mk_graph_slice (p, fast = True)
	else:
		rep = use_rep
	check_order = hints + split_sample_set (bound) + [bound]
	# bounds strictly outside this range won't be considered
	bound_range = [0, bound]
	best_bound_found = [None]
	# NOTE: this local function shadows the imported 'check' module
	# (unused within this function).
	def check (i):
		# test whether n is unreachable at visit i (assuming the Err
		# node is not reached); narrows bound_range and records the
		# best limit seen so far as side effects
		if i < bound_range[0]:
			return True
		if i > bound_range[1]:
			return False
		restrs2 = restrs + ((n, VisitCount (kind, i)), )
		pc = rep.get_pc ((n, restrs2))
		restrs3 = restr_others (p, restrs2, 2)
		epc = rep.get_pc (('Err', restrs3), tag = tag)
		hyp = mk_implies (mk_not (epc), mk_not (pc))
		res = rep.test_hyp_whyps (hyp, hyps)
		if res:
			trace ('split limit found: %d' % i)
			bound_range[1] = i - 1
			best_bound_found[0] = i
		else:
			bound_range[0] = i + 1
		return res

	# probe the sampled candidates first, then binary-search whatever
	# interval remains undecided
	for i in check_order:
		check (i)
	while bound_range[0] <= bound_range[1]:
		split = (bound_range[0] + bound_range[1]) // 2
		check (split)

	bound = best_bound_found[0]
	if bound == None:
		trace ('No split limit found for %d (%s).' % (n, tag))
		if must_find:
			assert not 'split limit found'
	return bound
115
def get_split_limit (p, n, restrs, hyps, kind, bound = 51,
		must_find = True, est_bound = 1, hints = None):
	"""Cached wrapper around find_split_limit. A cached miss (None
	limit) is only trusted if it was computed with at least the
	requested bound."""
	key = ('SplitLimit', n, restrs, tuple (hyps), kind)
	if key in p.cached_analysis:
		(lim, prev_bound) = p.cached_analysis[key]
		if lim != None or bound <= prev_bound:
			return lim
	if hints == None:
		# by default probe around the estimated bound first
		hints = [est_bound, est_bound + 1, est_bound + 2]
	lim = find_split_limit (p, n, restrs, hyps, kind,
		hints = hints, must_find = must_find, bound = bound)
	p.cached_analysis[key] = (lim, bound)
	return lim
129
def init_case_splits (p, hyps, tags = None):
	"""Identify pairs of possible case-split sites, one node per tag,
	whose path conditions are provably equal. Returns a list of
	(left node, right node) pairs, or None if candidate sites do not
	appear on both sides of the pairing."""
	if 'init_case_splits' in p.cached_analysis:
		return p.cached_analysis['init_case_splits']
	if tags == None:
		tags = p.pairing.tags
	poss = logic.possible_graph_divs (p)
	# need candidates on both sides to pair anything up
	if len (set ([p.node_tags[n][0] for n in poss])) < 2:
		return None
	rep = rep_graph.mk_graph_slice (p)
	assert all ([p.nodes[n].kind == 'Cond' for n in poss])
	# map path-condition expression -> candidate continuation nodes
	pc_map = logic.dict_list ([(rep.get_pc ((c, ())), c)
		for n in poss for c in p.nodes[n].get_conts ()
		if c not in p.loop_data])
	no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()])
	err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag))
		for tag in p.pairing.tags]
	# partition the path conditions into provable equivalence classes
	knowledge = EqSearchKnowledge (rep, hyps + err_pc_hyps, list (pc_map))
	last_knowledge[0] = knowledge
	pc_ids = knowledge.classify_vs ()
	id_n_map = logic.dict_list ([(i, n) for (pc, i) in pc_ids.iteritems ()
		for n in pc_map[pc]])
	# within each class, separate nodes by tag and pair up the first
	# node from each side
	tag_div_ns = [[[n for n in ns if p.node_tags[n][0] == t] for t in tags]
		for (i, ns) in id_n_map.iteritems ()]
	split_pairs = [(l_ns[0], r_ns[0]) for (l_ns, r_ns) in tag_div_ns
		if l_ns and r_ns]
	p.cached_analysis['init_case_splits'] = split_pairs
	return split_pairs
157
# log of case splits proposed by init_proof_case_split, for debugging
case_split_tr = []
159
def init_proof_case_split (p, restrs, hyps):
	"""Propose a 'CaseSplit' proof step from the precomputed initial
	case-split pairs, picking the first pair whose path condition is
	not already decided under the current hypotheses. Returns None if
	there is nothing new to split on."""
	ps = init_case_splits (p, hyps)
	if ps == None:
		return None
	p.cached_analysis.setdefault ('finished_init_case_splits', [])
	fin = p.cached_analysis['finished_init_case_splits']
	known_s = set.union (set (restrs), set (hyps))
	# skip if this context (or a subset of it) was already exhausted
	for rs in fin:
		if rs <= known_s:
			return None
	rep = rep_graph.mk_graph_slice (p)
	no_loop_restrs = tuple ([(n, vc_num (0)) for n in p.loop_heads ()])
	err_pc_hyps = [rep_graph.pc_false_hyp ((('Err', no_loop_restrs), tag))
		for tag in p.pairing.tags]
	for (n1, n2) in ps:
		pc = rep.get_pc ((n1, ()))
		# a split is only useful if the pc is neither provably true
		# nor provably false under the current hypotheses
		if rep.test_hyp_whyps (pc, hyps + err_pc_hyps):
			continue
		if rep.test_hyp_whyps (mk_not (pc), hyps + err_pc_hyps):
			continue
		case_split_tr.append ((n1, restrs, hyps))
		return ('CaseSplit', ((n1, p.node_tags[n1][0]), [n1, n2]))
	fin.append (known_s)
	return None
184
185# TODO: deal with all the code duplication between these two searches
class EqSearchKnowledge:
	"""Learning loop used to partition expressions into provable
	equivalence classes (see init_case_splits). Repeatedly asks the
	solver to either refute candidate equalities with a model or
	establish them as facts."""
	def __init__ (self, rep, hyps, vs):
		self.rep = rep
		self.hyps = hyps
		# candidate expression -> partition class id (all start in 1)
		self.v_ids = dict ([(v, 1) for v in vs])
		self.model_trace = []
		self.facts = set ()
		self.premise = foldr1 (mk_and, map (rep.interpret_hyp, hyps))

	def add_model (self, m):
		"""Record model m and split apart classes it distinguishes."""
		self.model_trace.append (m)
		update_v_ids_for_model2 (self, self.v_ids, m)

	def hyps_add_model (self, hyps):
		"""Present candidate hypotheses to the solver: 'unsat' makes
		them all facts, otherwise the countermodel refines v_ids."""
		if hyps:
			test_expr = foldr1 (mk_and, hyps)
		else:
			# we want to learn something, either a new model, or
			# that all hyps are true. if there are no hyps,
			# learning they're all true is learning nothing.
			# instead force a model
			test_expr = false_term
		test_expr = mk_implies (self.premise, test_expr)
		m = {}
		(r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
			{}, model = m)
		if r == 'unsat':
			if not hyps:
				trace ('WARNING: EqSearchKnowledge: premise unsat.')
				trace ("  ... learning procedure isn't going to work.")
			for hyp in hyps:
				self.facts.add (hyp)
		else:
			assert r == 'sat', r
			self.add_model (m)

	def classify_vs (self):
		"""Iterate until some equalities become facts or no candidate
		equalities remain; return the final partition map."""
		while not self.facts:
			hyps = v_id_eq_hyps (self.v_ids)
			if not hyps:
				break
			self.hyps_add_model (hyps)
		return self.v_ids
229
def update_v_ids_for_model2 (knowledge, v_ids, m):
	"""Refine the partition in v_ids using model m: expressions that
	shared a class but evaluate differently under m get distinct new
	class ids."""
	# first update the live variables
	ev = lambda v: eval_model_expr (m, knowledge.rep.solv, v)
	# group by (old class id, value under this model)
	groups = logic.dict_list ([((k, ev (v)), v)
		for (v, k) in v_ids.iteritems ()])
	v_ids.clear ()
	for (i, kt) in enumerate (sorted (groups)):
		for v in groups[kt]:
			v_ids[v] = i
239
def v_id_eq_hyps (v_ids):
	"""Build the candidate equality hypotheses implied by the current
	partition: within each class, equate every member to the first."""
	groups = logic.dict_list ([(k, v) for (v, k) in v_ids.iteritems ()])
	return [mk_eq (v, vs[0])
		for vs in groups.itervalues ()
		for v in vs[1:]]
247
class SearchKnowledge:
	"""Learning loop for the split search: maintains a partition of
	candidate loop variables (v_ids), candidate loop-element pairings
	(pairs), solver-established facts, and the models seen so far."""
	def __init__ (self, rep, name, restrs, hyps, tags, cand_elts = None):
		self.rep = rep
		self.name = name
		self.restrs = restrs
		self.hyps = hyps
		self.tags = tags
		if cand_elts != None:
			(loop_elts, r_elts) = cand_elts
		else:
			(loop_elts, r_elts) = ([], [])
		(pairs, vs) = init_knowledge_pairs (rep, loop_elts, r_elts)
		self.pairs = pairs
		self.v_ids = vs
		self.model_trace = []
		self.facts = set ()
		self.weak_splits = set ()
		self.premise = syntax.true_term
		self.live_pairs_trace = []

	def add_model (self, m):
		"""Record model m and refine the variable partition with it."""
		self.model_trace.append (m)
		update_v_ids_for_model (self, self.pairs, self.v_ids, m)

	def hyps_add_model (self, hyps, assert_progress = True):
		"""Present candidate hypotheses to the solver: 'unsat' makes
		them all facts, otherwise the countermodel refines v_ids.
		With assert_progress, insist something new was learned."""
		if hyps:
			test_expr = foldr1 (mk_and, hyps)
		else:
			# we want to learn something, either a new model, or
			# that all hyps are true. if there are no hyps,
			# learning they're all true is learning nothing.
			# instead force a model
			test_expr = false_term
		test_expr = mk_implies (self.premise, test_expr)
		m = {}
		(r, _) = self.rep.solv.parallel_check_hyps ([(1, test_expr)],
			{}, model = m)
		if r == 'unsat':
			if not hyps:
				trace ('WARNING: SearchKnowledge: premise unsat.')
				trace ("  ... learning procedure isn't going to work.")
				return
			if assert_progress:
				assert not (set (hyps) <= self.facts), hyps
			for hyp in hyps:
				self.facts.add (hyp)
		else:
			assert r == 'sat', r
			self.add_model (m)
			if assert_progress:
				assert self.model_trace[-2:-1] != [m]

	def eqs_add_model (self, eqs, assert_progress = True):
		"""Expand candidate variable pairings into predicates (new
		ones only) and feed them to hyps_add_model."""
		preds = [pred for vpair in eqs
			for pred in expand_var_eqs (self, vpair)
			if pred not in self.facts]

		self.hyps_add_model (preds,
			assert_progress = assert_progress)

	def add_weak_split (self, eqs):
		"""Remember this split candidate (by its predicate set) so it
		is not proposed again."""
		preds = [pred for vpair in eqs
                        for pred in expand_var_eqs (self, vpair)]
		self.weak_splits.add (tuple (sorted (preds)))

	def is_weak_split (self, eqs):
		"""Check whether this split candidate was already recorded."""
		preds = [pred for vpair in eqs
                        for pred in expand_var_eqs (self, vpair)]
		return tuple (sorted (preds)) in self.weak_splits
317
def init_knowledge_pairs (rep, loop_elts, cand_r_loop_elts):
	"""Set up the search state: candidate variable lists for each
	(left element, right element) pairing, and the initial variable
	partition (all variables start classed by type, flagged as
	possibly constant)."""
	trace ('Doing search knowledge setup now.')
	v_is = [(i, i_offs, i_step,
		[(v, i, i_offs, i_step) for v in get_loop_vars_at (rep.p, i)])
		for (i, i_offs, i_step) in sorted (loop_elts)]
	l_vtyps = set ([v[0].typ for (_, _, _, vs) in v_is for v in vs])
	# right-side variables can only match a left variable of the same
	# type, so drop any type not occurring on the left
	v_js = [(j, j_offs, j_step,
		[(v, j, j_offs, j_step) for v in get_loop_vars_at (rep.p, j)
			if v.typ in l_vtyps])
		for (j, j_offs, j_step) in sorted (cand_r_loop_elts)]
	vs = {}
	for (_, _, _, var_vs) in v_is + v_js:
		for v in var_vs:
			# initial class key is the type; True = may be constant
			vs[v] = (v[0].typ, True)
	pairs = {}
	for (i, i_offs, i_step, i_vs) in v_is:
		for (j, j_offs, j_step, j_vs) in v_js:
			pair = ((i, i_offs, i_step), (j, j_offs, j_step))
			pairs[pair] = (i_vs, j_vs)
	trace ('... done.')
	return (pairs, vs)
339
def update_v_ids_for_model (knowledge, pairs, vs, m):
	"""Refine the variable partition vs using model m, mark pairings
	that can no longer be matched as 'Failed', and drop variables no
	surviving pairing needs."""
	rep = knowledge.rep
	# first update the live variables
	groups = {}
	for v in vs:
		(k, const) = vs[v]
		groups.setdefault (k, [])
		groups[k].append ((v, const))
	k_counter = 1
	vs.clear ()
	for k in groups:
		for (const, xs) in split_group (knowledge, m, groups[k]):
			for x in xs:
				vs[x] = (k_counter, const)
			k_counter += 1
	# then figure out which pairings are still viable
	needed_ks = set ()
	zero = syntax.mk_word32 (0)
	for (pair, data) in pairs.items ():
		if data[0] == 'Failed':
			continue
		(lvs, rvs) = data
		# left classes that must find a right-side match: the zero
		# baseline, or any class not flagged as constant
		lv_ks = set ([vs[v][0] for v in lvs
			if v[0] == zero or not vs[v][1]])
		rv_ks = set ([vs[v][0] for v in rvs])
		miss_vars = lv_ks - rv_ks
		if miss_vars:
			lv_miss = [v[0] for v in lvs if vs[v][0] in miss_vars]
			pairs[pair] = ('Failed', lv_miss.pop ())
		else:
			needed_ks.update ([vs[v][0] for v in lvs + rvs])
	# then drop any vars which are no longer relevant
	# (Python 2: keys () yields a list, so deleting while iterating
	# over it is safe)
	for v in vs.keys ():
		if vs[v][0] not in needed_ks:
			del vs[v]
375
def get_entry_visits_up_to (rep, head, restrs, hyps):
	"""get the set of nodes visited on the entry path entry
	to the loop, up to and including the head point."""
	k = ('loop_visits_up_to', head, restrs, tuple (hyps))
	if k in rep.p.cached_analysis:
		return rep.p.cached_analysis[k]

	# requires a unique loop entry site under these restrictions
	[entry] = get_loop_entry_sites (rep, restrs, hyps, head)
	frontier = set ([entry])
	up_to = set ()
	loop = rep.p.loop_body (head)
	while frontier:
		n = frontier.pop ()
		# don't explore past the head point
		if n == head:
			continue
		new_conts = [n2 for n2 in rep.p.nodes[n].get_conts ()
			if n2 in loop if n2 not in up_to]
		up_to.update (new_conts)
		frontier.update (new_conts)
	rep.p.cached_analysis[k] = up_to
	return up_to
397
def get_nth_visit_restrs (rep, restrs, hyps, i, visit_num):
	"""Build restrictions for the visit_num-th visit to node i, using
	its loop head as the restriction point. Tricky because the entry
	path may reach the head before i or vice-versa, which changes how
	often the head is visited relative to i."""
	head = rep.p.loop_id (i)
	entry_ns = get_entry_visits_up_to (rep, head, restrs, hyps)
	# if i lies on the entry path at or before the head, the head is
	# visited no more often than i; otherwise one extra time
	offs = 0 if i in entry_ns else 1
	return ((head, vc_num (visit_num + offs)), ) + restrs
412
def get_var_pc_var_list (knowledge, v_i):
	"""For candidate v_i = (var, node, offs, step), return smt
	(path condition, value) pairs for the visits at offs, offs + step
	and offs + 2 * step to the node."""
	rep = knowledge.rep
	(v_i, i, i_offs, i_step) = v_i
	def get_var (k):
		# value and pc of v_i at the k-th visit to node i
		restrs2 = get_nth_visit_restrs (rep, knowledge.restrs,
				knowledge.hyps, i, k)
		(pc, env) = rep.get_node_pc_env ((i, restrs2))
		return (to_smt_expr (pc, env, rep.solv),
			to_smt_expr (v_i, env, rep.solv))
	return [get_var (i_offs + (k * i_step))
		for k in [0, 1, 2]]
424
def expand_var_eqs (knowledge, (v_i, v_j)):
	"""Expand a candidate pairing (v_i, v_j) - or (v_i, 'Const') -
	into concrete smt equality predicates over the sampled visits.
	(NOTE: the tuple parameter here is Python-2-only syntax.)"""
	if v_j == 'Const':
		pc_vs = get_var_pc_var_list (knowledge, v_i)
		(_, v0) = pc_vs[0]
		# constant: every later visit's value equals the first's
		return [mk_implies (pc, mk_eq (v, v0))
			for (pc, v) in pc_vs[1:]]
	# sorting the vars guarantees we generate the same
	# mem eqs each time which is important for the solver
	(v_i, v_j) = sorted ([v_i, v_j])
	pc_vs = zip (get_var_pc_var_list (knowledge, v_i),
		get_var_pc_var_list (knowledge, v_j))
	return [pred for ((pc_i, v_i), (pc_j, v_j)) in pc_vs
		for pred in [mk_eq (pc_i, pc_j),
			mk_implies (pc_i, logic.mk_eq_with_cast (v_i, v_j))]]
439
# interpretations of SMT bitvector operators on Python ints; the
# caller (eval_model) masks results back to the operand word width.
# NOTE: '/' and '%' here follow Python (Python 2 floor) semantics on
# the non-negative masked values eval_model supplies.
word_ops = {'bvadd':lambda x, y: x + y, 'bvsub':lambda x, y: x - y,
	'bvmul':lambda x, y: x * y, 'bvurem':lambda x, y: x % y,
	'bvudiv':lambda x, y: x / y, 'bvand':lambda x, y: x & y,
	'bvor':lambda x, y: x | y, 'bvxor': lambda x, y: x ^ y,
	'bvnot': lambda x: ~ x, 'bvneg': lambda x: - x,
	'bvshl': lambda x, y: x << y, 'bvlshr': lambda x, y: x >> y}

# interpretations of SMT boolean operators on Python bools
bool_ops = {'=>':lambda x, y: (not x) or y, '=': lambda x, y: x == y,
	'not': lambda x: not x, 'true': lambda: True, 'false': lambda: False}

# word comparison operators, paired with whether operands are read as
# 'Signed' or 'Unsigned' values
word_ineq_ops = {'=': (lambda x, y: x == y, 'Unsigned'),
	'bvult': (lambda x, y: x < y, 'Unsigned'),
	'word32-eq': (lambda x, y: x == y, 'Unsigned'),
	'bvule': (lambda x, y: x <= y, 'Unsigned'),
	'bvsle': (lambda x, y: x <= y, 'Signed'),
	'bvslt': (lambda x, y: x < y, 'Signed'),
}
457
def eval_model (m, s, toplevel = None):
	"""Evaluate parsed SMT s-expression s against model m, memoising
	results in m. Scalar results are syntax Expr values; memory
	objects are (naming, eqs) pairs, eqs being a sorted tuple of
	(address, word value) items.

	Fixes: the load/store branches previously rebound the model
	variable m to the memory value, so the final 'm[s] = result'
	memoisation would fail with a TypeError on a tuple; store-word8
	also wrote the merged word back at the unaligned byte address
	p.val although load-word8 reads at the aligned address p_al.
	"""
	if s in m:
		return m[s]
	if toplevel == None:
		toplevel = s
	if type (s) == str:
		# leaf: an SMT literal
		try:
			result = solver.smt_to_val (s)
		except Exception as e:
			trace ('Error with eval_model')
			trace (toplevel)
			raise e
		return result

	op = s[0]

	if op == 'ite':
		# evaluate the guard first, then only the taken branch
		[_, b, x, y] = s
		b = eval_model (m, b, toplevel)
		assert b in [false_term, true_term]
		if b == true_term:
			result = eval_model (m, x, toplevel)
		else:
			result = eval_model (m, y, toplevel)
		m[s] = result
		return result

	xs = [eval_model (m, x, toplevel) for x in s[1:]]

	if op[0] == '_' and op[1] in ['zero_extend', 'sign_extend']:
		[_, ex_kind, n_extend] = op
		n_extend = int (n_extend)
		[x] = xs
		assert x.typ.kind == 'Word' and x.kind == 'Num'
		if ex_kind == 'sign_extend':
			val = get_signed_val (x)
		else:
			val = get_unsigned_val (x)
		result = mk_num (val, x.typ.num + n_extend)
	elif op[0] == '_' and op[1] == 'extract':
		[_, _, n_top, n_bot] = op
		n_top = int (n_top)
		n_bot = int (n_bot)
		[x] = xs
		assert x.typ.kind == 'Word' and x.kind == 'Num'
		# extract bounds are inclusive at both ends
		length = (n_top - n_bot) + 1
		result = mk_num ((x.val >> n_bot) & ((1 << length) - 1), length)
	elif op[0] == 'store-word32':
		(mem, p, v) = xs
		(naming, eqs) = mem
		eqs = dict (eqs)
		eqs[p.val] = v.val
		eqs = tuple (sorted (eqs.items ()))
		result = (naming, eqs)
	elif op[0] == 'store-word8':
		(mem, p, v) = xs
		# read-modify-write of the enclosing aligned word
		p_al = p.val & -4
		shift = (p.val & 3) * 8
		(naming, eqs) = mem
		eqs = dict (eqs)
		prev_v = eqs[p_al]
		mask_v = prev_v & (((1 << 32) - 1) ^ (255 << shift))
		new_v = mask_v | ((v.val & 255) << shift)
		# store back at the aligned address, matching load-word8
		eqs[p_al] = new_v
		eqs = tuple (sorted (eqs.items ()))
		result = (naming, eqs)
	elif op[0] == 'load-word32':
		(mem, p) = xs
		(naming, eqs) = mem
		eqs = dict (eqs)
		result = syntax.mk_word32 (eqs[p.val])
	elif op[0] == 'load-word8':
		(mem, p) = xs
		p_al = p.val & -4
		shift = (p.val & 3) * 8
		(naming, eqs) = mem
		eqs = dict (eqs)
		v = (eqs[p_al] >> shift) & 255
		result = syntax.mk_word8 (v)
	elif xs and xs[0].typ.kind == 'Word' and op in word_ops:
		for x in xs:
			assert x.kind == 'Num', (s, op, x)
		result = word_ops[op](* [x.val for x in xs])
		# mask the arithmetic result back to the operand width
		result = result & ((1 << xs[0].typ.num) - 1)
		result = Expr ('Num', xs[0].typ, val = result)
	elif xs and xs[0].typ.kind == 'Word' and op in word_ineq_ops:
		(oper, signed) = word_ineq_ops[op]
		if signed == 'Signed':
			result = oper (* map (get_signed_val, xs))
		else:
			assert signed == 'Unsigned'
			result = oper (* [x.val for x in xs])
		result = {True: true_term, False: false_term}[result]
	elif op == 'and':
		result = all ([x == true_term for x in xs])
		result = {True: true_term, False: false_term}[result]
	elif op == 'or':
		result = bool ([x for x in xs if x == true_term])
		result = {True: true_term, False: false_term}[result]
	elif op in bool_ops:
		assert all ([x.typ == boolT for x in xs])
		result = bool_ops[op](* [x == true_term for x in xs])
		result = {True: true_term, False: false_term}[result]
	else:
		assert not 's_expr handled', (s, op)
	m[s] = result
	return result
565
def get_unsigned_val (x):
	"""Return the value of word constant x as an unsigned integer."""
	assert x.typ.kind == 'Word'
	assert x.kind == 'Num'
	mask = (1 << x.typ.num) - 1
	return x.val & mask
572
def get_signed_val (x):
	"""Return the value of word constant x as a signed (two's
	complement) integer."""
	assert x.typ.kind == 'Word'
	assert x.kind == 'Num'
	bits = x.typ.num
	v = x.val & ((1 << bits) - 1)
	# if the sign bit is set, wrap down into the negative range
	if v & (1 << (bits - 1)):
		v -= (1 << bits)
	return v
581
def short_array_str (arr):
	"""Render a model memory mapping as a compact '{addr: val}' hex
	string. Word addresses are scaled by 4 to byte addresses;
	string-keyed bookkeeping entries are skipped."""
	entries = []
	for (p, v) in arr.iteritems ():
		if type (p) == str:
			continue
		entries.append ('%x: %x' % (p.val * 4, v.val))
	entries.sort ()
	return '{%s}' % ', '.join (entries)
588
def eval_model_expr (m, solv, v):
	"""Evaluate expression v under model m by converting it to an SMT
	s-expression and interpreting that against the model."""
	sexp = solver.parse_s_expression (solver.smt_expr (v, {}, solv))
	return eval_model (m, sexp)
594
def model_equal (m, knowledge, vpair):
	"""Check whether the candidate pairing vpair holds under model m,
	i.e. all of its expanded equality predicates evaluate true."""
	for pred in expand_var_eqs (knowledge, vpair):
		val = eval_model_expr (m, knowledge.rep.solv, pred)
		assert val in [syntax.true_term, syntax.false_term]
		if val != syntax.true_term:
			return False
	return True
603
def get_model_trace (knowledge, m, v):
	"""Evaluate candidate variable v at each of its sampled visits
	under model m, giving a tuple of values (None at visits whose
	path condition is false)."""
	rep = knowledge.rep
	pc_vs = get_var_pc_var_list (knowledge, v)
	# NOTE: this local list shadows the imported trace function
	trace = []
	for (pc, v) in pc_vs:
		x = eval_model_expr (m, rep.solv, pc)
		assert x in [syntax.true_term, syntax.false_term]
		if x == syntax.false_term:
			trace.append (None)
		else:
			trace.append (eval_model_expr (m, rep.solv, v))
	return tuple (trace)
616
def split_group (knowledge, m, group):
	"""Split a group of (variable, const-flag) candidates into bins
	whose members evaluate equal under model m. Returns a sequence of
	(const-flag, [variable]) bins."""
	group = list (set (group))
	if group[0][0][0].typ == syntax.builtinTs['Mem']:
		# memory values: compare pairwise via the model, since they
		# can't be used directly as dictionary keys
		bins = []
		for (v, const) in group:
			for i in range (len (bins)):
				if model_equal (m, knowledge,
						(v, bins[i][1][0])):
					bins[i][1].append (v)
					break
			else:
				# new bin; keep the const flag only if the
				# model also supports v being constant
				if const:
					const = model_equal (m, knowledge,
						(v, 'Const'))
				bins.append ((const, [v]))
		return bins
	else:
		# scalar values: bin directly on the tuple of model values
		bins = {}
		for (v, const) in group:
			trace = get_model_trace (knowledge, m, v)
			if trace not in bins:
				# constant iff at most one distinct value seen
				tconst = len (set (trace) - set ([None])) <= 1
				bins[trace] = (const and tconst, [])
			bins[trace][1].append (v)
		return bins.values ()
642
def mk_pairing_v_eqs (knowledge, pair, endorsed = True):
	"""Build a candidate variable-equality list for loop pairing
	'pair': each left variable is matched to 'Const' or to a right
	variable in the same partition class (with endorsed, only matches
	already established as facts count). Returns None if any left
	variable cannot be matched."""
	v_eqs = []
	(lvs, rvs) = knowledge.pairs[pair]
	zero = mk_word32 (0)
	for v_i in lvs:
		(k, const) = knowledge.v_ids[v_i]
		if const and v_i[0] != zero:
			if not endorsed or eq_known (knowledge, (v_i, 'Const')):
				v_eqs.append ((v_i, 'Const'))
				continue
		vs_j = [v_j for v_j in rvs if knowledge.v_ids[v_j][0] == k]
		if endorsed:
			vs_j = [v_j for v_j in vs_j
				if eq_known (knowledge, (v_i, v_j))]
		if not vs_j:
			return None
		v_j = vs_j[0]
		v_eqs.append ((v_i, v_j))
	return v_eqs
662
def eq_known (knowledge, vpair):
	"""True iff every equality predicate for candidate pairing vpair
	is an established fact."""
	preds = expand_var_eqs (knowledge, vpair)
	return set (preds).issubset (knowledge.facts)
666
def find_split_loop (p, head, restrs, hyps, unfold_limit = 9,
		node_restrs = None, trace_ind_fails = None):
	"""Find a split proof step for the loop at head, trying candidate
	(i, j) sequence options in order of increasing model size. Raises
	NoSplit if no option yields a split."""
	assert p.loop_data[head][0] == 'Head'
	assert p.node_tags[head][0] == p.pairing.tags[0]

	# the idea is to loop through testable hyps, starting with ones that
	# need smaller models (the most unfolded models will time out for
	# large problems like finaliseSlot)

	rep = mk_graph_slice (p, fast = True)

	# structure may force a specific step or option set first
	nec = get_necessary_split_opts (p, head, restrs, hyps)
	if nec and nec[0] in ['CaseSplit', 'LoopUnroll']:
		return nec
	elif nec:
		i_j_opts = nec
	else:
		i_j_opts = default_i_j_opts (unfold_limit)

	if trace_ind_fails == None:
		ind_fails = []
	else:
		# caller-supplied list collects inductive failures
		ind_fails = trace_ind_fails
	for (i_opts, j_opts) in i_j_opts:
		result = find_split (rep, head, restrs, hyps,
			i_opts, j_opts, node_restrs = node_restrs)
		if result[0] != None:
			return result
		ind_fails.extend (result[1])

	if ind_fails:
		trace ('Warning: inductive failures: %s' % ind_fails)
	raise NoSplit ()
700
def default_i_j_opts (unfold_limit = 9):
	"""Default candidate (i, j) sequence options, limited by the given
	unfold depth."""
	return mk_i_j_opts (unfold_limit = unfold_limit)
703
def mk_i_j_opts (i_seq_opts = None, j_seq_opts = None, unfold_limit = 9):
	"""Stage the (start, step) sequence options by cost: for each
	unfold depth at which some option first becomes checkable, pair
	the left and right options affordable at that depth. Only depths
	with options on both sides are kept."""
	if i_seq_opts == None:
		i_seq_opts = [(0, 1), (1, 1), (2, 1), (3, 1)]
	if j_seq_opts == None:
		j_seq_opts = [(0, 1), (0, 2), (1, 1), (1, 2),
			(2, 1), (2, 2), (3, 1)]
	opt_union = set (i_seq_opts + j_seq_opts)

	def cost (start, step):
		# visits needed to sample a (start, step) sequence option
		return start + (2 * step) + 1

	def affordable (opts, lim):
		return [(start, step) for (start, step) in opts
			if cost (start, step) <= lim]

	staged = []
	for lim in range (unfold_limit):
		# only consider depths where some option first fits exactly
		if not [1 for (start, step) in opt_union
				if cost (start, step) == lim]:
			continue
		i_opts = affordable (i_seq_opts, lim)
		j_opts = affordable (j_seq_opts, lim)
		if i_opts and j_opts:
			staged.append ((i_opts, j_opts))
	return staged
723
# log of rejected/failed necessary-split candidates, for debugging
necessary_split_opts_trace = []
725
def get_interesting_linear_series_exprs (p, head):
	"""Cached lookup of the interesting linear-series expressions for
	the loop at head."""
	key = ('interesting_linear_series', head)
	try:
		return p.cached_analysis[key]
	except KeyError:
		pass
	res = logic.interesting_linear_series_exprs (p, head,
		get_loop_var_analysis_at (p, head))
	p.cached_analysis[key] = res
	return res
734
def split_opt_test (p, tags = None):
	"""Testing aid: compute the necessary split options for every
	left-tagged loop head of problem p."""
	if not tags:
		tags = p.pairing.tags
	heads = [head for head in init_loops_to_split (p, ())
		if p.node_tags[head][0] == tags[0]]
	hyps = check.init_point_hyps (p)
	return [(head, get_necessary_split_opts (p, head, (), hyps))
		for head in heads]
743
def interesting_linear_test (p):
	"""Testing aid: check that every 'interesting' linear series
	expression found for p's loops really advances by its computed
	offset on each loop iteration."""
	p.do_analysis ()
	for head in p.loop_heads ():
		inter = get_interesting_linear_series_exprs (p, head)
		hooks = target_objects.hooks ('loop_var_analysis')
		# skip nodes whose analysis was overridden by a hook
		n_exprs = [(n, expr, offs) for (n, vs) in inter.iteritems ()
			if not [hook for hook in hooks if hook (p, n) != None]
			for (kind, expr, offs) in vs]
		if n_exprs:
			rep = rep_graph.mk_graph_slice (p)
		for (n, expr, offs) in n_exprs:
			restrs = tuple ([(n2, vc) for (n2, vc)
				in restr_others_both (p, (), 2, 2)
				if p.loop_id (n2) != p.loop_id (head)])
			vis1 = (n, ((head, vc_offs (1)), ) + restrs)
			vis2 = (n, ((head, vc_offs (2)), ) + restrs)
			pc = rep.get_pc (vis2)
			# expr at visit i+1 must equal expr + offs at visit i
			imp = mk_implies (pc, mk_eq (rep.to_smt_expr (expr, vis2),
				rep.to_smt_expr (mk_plus (expr, offs), vis1)))
			assert rep.test_hyp_whyps (imp, [])
	return True
765
# debug cell: arguments of the most recent get_necessary_split_opts call
last_necessary_split_opts = [0]
767
def get_necessary_split_opts (p, head, restrs, hyps, tags = None):
	"""Look for split options forced by the problem structure:
	multiple loop entry sites demand a CaseSplit, matching linear
	series dictate specific (i, j) options, and some series instead
	indicate a LoopUnroll. Returns a proof-step tuple, an option
	list, or None."""
	if not tags:
		tags = p.pairing.tags
	[l_tag, r_tag] = tags
	last_necessary_split_opts[0] = (p, head, restrs, hyps, tags)

	rep = rep_graph.mk_graph_slice (p, fast = True)
	# any loop with more than one entry site must be case-split first
	entries = get_loop_entry_sites (rep, restrs, hyps, head)
	if len (entries) > 1:
		return ('CaseSplit', ((entries[0], tags[0]), [entries[0]]))
	for n in init_loops_to_split (p, restrs):
		if p.node_tags[n][0] != r_tag:
			continue
		entries = get_loop_entry_sites (rep, restrs, hyps, n)
		if len (entries) > 1:
			return ('CaseSplit', ((entries[0], r_tag),
				[entries[0]]))

	stuff = linear_setup_stuff (rep, head, restrs, hyps, tags)
	if stuff == None:
		return None
	seq_eqs = get_matching_linear_seqs (rep, head, restrs, hyps, tags)

	vis = stuff['vis']
	for v in seq_eqs:
		if v[0] == 'LoopUnroll':
			(_, n, est_bound) = v
			lim = get_split_limit (p, n, restrs, hyps, 'Number',
				est_bound = est_bound, must_find = False)
			if lim != None:
				return ('LoopUnroll', n)
			continue
		((n, expr), (n2, expr2), (l_start, l_step), (r_start, r_step),
			_, _) = v
		# require the matched series to be equal at two visits
		eqs = [rep_graph.eq_hyp ((expr,
			(vis (n, l_start + (i * l_step)), l_tag)),
			(expr2, (vis (n2, r_start + (i * r_step)), r_tag)))
			for i in range (2)]
		vis_hyp = rep_graph.pc_true_hyp ((vis (n, l_start), l_tag))
		vis_hyps = [vis_hyp] + stuff['hyps']
		eq = foldr1 (mk_and, map (rep.interpret_hyp, eqs))
		m = {}
		if rep.test_hyp_whyps (eq, vis_hyps, model = m):
			trace ('found necessary split info: (%s, %s), (%s, %s)'
				% (l_start, l_step, r_start, r_step))
			return mk_i_j_opts ([(l_start + i, l_step)
					for i in range (r_step + 1)],
				[(r_start + i, r_step)
					for i in range (l_step + 1)],
				unfold_limit = 100)
		n_vcs = entry_path_no_loops (rep, l_tag, m, head)
		path_hyps = [rep_graph.pc_true_hyp ((n_vc, l_tag)) for n_vc in n_vcs]
		if rep.test_hyp_whyps (eq, stuff['hyps'] + path_hyps):
			# immediate case split on difference between entry paths
			checks = [(stuff['hyps'], eq_hyp, 'eq')
				for eq_hyp in eqs]
			return derive_case_split (rep, n_vcs, checks)
		necessary_split_opts_trace.append ((n, expr, (l_start, l_step),
			(r_start, r_step), 'Seq check failed'))
	return None
828
def linear_setup_stuff (rep, head, restrs, hyps, tags):
	"""Gather (and cache) the shared context for linear-series
	matching at loop head: the interesting series expressions on both
	sides, extended hypotheses, and helpers for building visit tuples
	and smt expressions. Returns None if either side has no series."""
	[l_tag, r_tag] = tags
	k = ('linear_seq setup', head, restrs, tuple (hyps), tuple (tags))
	p = rep.p
	if k in p.cached_analysis:
		return p.cached_analysis[k]

	assert p.node_tags[head][0] == l_tag
	l_seq_vs = get_interesting_linear_series_exprs (p, head)
	if not l_seq_vs:
		return None
	r_seq_vs = {}
	# NOTE(review): restr_env appears unused below
	restr_env = {p.loop_id (head): restrs}
	for n in init_loops_to_split (p, restrs):
		if p.node_tags[n][0] != r_tag:
			continue
		vs = get_interesting_linear_series_exprs (p, n)
		r_seq_vs.update (vs)
	if not r_seq_vs:
		return None

	def vis (n, i):
		# the i-th visit to n, restricted via its loop head
		restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, i)
		return (n, restrs2)
	smt = lambda expr, n, i: rep.to_smt_expr (expr, vis (n, i))
	smt_pc = lambda n, i: rep.get_pc (vis (n, i))

	# remove duplicates by concretising
	l_seq_vs = dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset))
		for n in l_seq_vs
		for (kind, expr, offs, oset) in l_seq_vs[n]]).values ()
	r_seq_vs = dict ([(smt (expr, n, 2), (kind, n, expr, offs, oset))
                for n in r_seq_vs
		for (kind, expr, offs, oset) in r_seq_vs[n]]).values ()

	# extend the hypotheses with trivial-pc hyps for the third visit
	# to each relevant node, plus the no-Err hypothesis
	hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), r_tag))
		for n in set ([n for (_, n, _, _, _) in r_seq_vs])]
	hyps = hyps + [rep_graph.pc_triv_hyp ((vis (n, 3), l_tag))
		for n in set ([n for (_, n, _, _, _) in l_seq_vs])]
	hyps = hyps + [check.non_r_err_pc_hyp (tags,
			restr_others (p, restrs, 2))]

	r = {'l_seq_vs': l_seq_vs, 'r_seq_vs': r_seq_vs,
		'hyps': hyps, 'vis': vis, 'smt': smt, 'smt_pc': smt_pc}
	p.cached_analysis[k] = r
	return r
875
def get_matching_linear_seqs (rep, head, restrs, hyps, tags):
	"""Lazily generate candidate matches between left and right linear
	series at loop head (sequence-equality descriptions and/or
	'LoopUnroll' suggestions). The generator is cached with
	itertools.tee so a repeated call replays the same stream."""
	k = ('matching linear seqs', head, restrs, tuple (hyps), tuple (tags))
	p = rep.p
	if k in p.cached_analysis:
		v = p.cached_analysis[k]
		# duplicate the cached iterator: keep one copy, hand one out
		(x, y) = itertools.tee (v[0])
		v[0] = x
		return y

	[l_tag, r_tag] = tags
	stuff = linear_setup_stuff (rep, head, restrs, hyps, tags)
	if stuff == None:
		return []

	hyps = stuff['hyps']
	vis = stuff['vis']

	def get_model (n, offs):
		# try to refute '4 * offs == 0' at visits 1/2 of n; the
		# interesting outcome is the countermodel, not the verdict
		m = {}
		offs_smt = stuff['smt'] (offs, n, 1)
		eq = mk_eq (mk_times (offs_smt, mk_num (4, offs_smt.typ)),
			mk_num (0, offs_smt.typ))
		ex_hyps = [rep_graph.pc_true_hyp ((vis (n, 1), l_tag)),
			rep_graph.pc_true_hyp ((vis (n, 2), l_tag))]
		res = rep.test_hyp_whyps (eq, hyps + ex_hyps, model = m)
		if not m:
			# NOTE(review): 'kind' is not in scope here (generator
			# expression variables don't leak), so this line would
			# raise NameError if reached - confirm
			necessary_split_opts_trace.append ((n, kind, 'NoModel'))
			return None
		return m

	r = (seq_eq
		for (kind, n, expr, offs, oset) in sorted (stuff['l_seq_vs'])
		if [v for v in stuff['r_seq_vs'] if v[0] == kind]
		for m in [get_model (n, offs)]
		if m
		for seq_eq in [get_linear_seq_eq (rep, m, stuff,
					(kind, n, expr, offs, oset)),
			get_model_r_side_unroll (rep, tags, m,
				restrs, hyps, stuff)]
		if seq_eq != None)
	(x, y) = itertools.tee (r)
	p.cached_analysis[k] = [y]
	return x
919
def get_linear_seq_eq (rep, m, stuff, expr_t1):
	"""Using countermodel m, find a right-side linear series matching
	left series expr_t1: the offsets must divide evenly (ratio at
	most 8) and the start difference must align with the right-side
	offset. Returns the match description or None.
	(Python 2 integer division is relied on throughout.)"""
	def get_int_min (expr):
		# evaluate expr in m and pick the representative of its
		# congruence class (mod 2^width) with least absolute value
		v = eval_model_expr (m, rep.solv, expr)
		assert v.kind == 'Num', v
		vs = [v.val + (i << v.typ.num) for i in range (-2, 3)]
		(_, v) = min ([(abs (v), v) for v in vs])
		return v
	(kind, n1, expr1, offs1, oset1) = expr_t1
	smt = stuff['smt']
	expr_init = smt (expr1, n1, 0)
	# NOTE(review): expr_v appears unused below
	expr_v = get_int_min (expr_init)
	offs_v = get_int_min (smt (offs1, n1, 1))
	r_seqs = [(n, expr, offs, oset2,
			get_int_min (mk_minus (expr_init, smt (expr, n, 0))),
			get_int_min (smt (offs, n, 0)))
		for (kind2, n, expr, offs, oset2) in sorted (stuff['r_seq_vs'])
		if kind2 == kind]

	for (n, expr, offs2, oset2, diff, offs_v2) in sorted (r_seqs):
		mult = offs_v / offs_v2
		if offs_v % offs_v2 != 0 or mult > 8:
			necessary_split_opts_trace.append ((n, expr,
				'StepWrong', offs_v, offs_v2))
		elif diff % offs_v2 != 0 or (diff * offs_v2) < 0 or (diff / offs_v2) > 8:
			necessary_split_opts_trace.append ((n, expr,
				'StartWrong', diff, offs_v2))
		else:
			return ((n1, expr1), (n, expr), (0, 1),
				(diff / offs_v2, mult), (offs1, offs2),
				(oset1, oset2))
	return None
951
# debug aid: arguments of the most recent get_model_r_side_unroll call
last_r_side_unroll = [None]
953
def get_model_r_side_unroll (rep, tags, m, restrs, hyps, stuff):
	"""Look in model m for evidence that a rhs loop must be unrolled:
	step each rhs linear sequence up to 64 times and check whether it
	hits values the lhs only reaches at singly-visited points. If so,
	return ('LoopUnroll', loop id, hit count), else None."""
	p = rep.p
	[l_tag, r_tag] = tags
	last_r_side_unroll[0] = (rep, tags, m, restrs, hyps, stuff)

	r_kinds = set ([kind for (kind, n, _, _, _) in stuff['r_seq_vs']])
	# lhs nodes visited in this model, grouped by node
	l_visited_ns_vcs = logic.dict_list ([(n, vc)
		for (tag, n, vc) in rep.node_pc_env_order
		if tag == l_tag
		if eval_pc (rep, m, (n, vc))])
	# interesting word-typed exprs at lhs nodes visited exactly once
	l_arc_interesting = [(n, vc, kind, expr)
		for (n, vcs) in l_visited_ns_vcs.iteritems ()
		if len (vcs) == 1
		for vc in vcs
		for (kind, expr)
			in logic.interesting_node_exprs (p, n, tags = tags)
		if kind in r_kinds
		if expr.typ.kind == 'Word']
	l_kinds = set ([kind for (n, vc, kind, _) in l_arc_interesting])

	# FIXME: cloned
	def canon_n (n, typ):
		# canonicalise a word value to the signed representative
		# of least absolute value
		vs = [n + (i << typ.num) for i in range (-2, 3)]
		(_, v) = min ([(abs (v), v) for v in vs])
		return v
	def get_int_min (expr):
		v = eval_model_expr (m, rep.solv, expr)
		assert v.kind == 'Num', v
		return canon_n (v.val, v.typ)
	# renamed from 'eval' to avoid shadowing the python builtin
	def eval_at (expr, n, vc):
		expr = rep.to_smt_expr (expr, (n, vc))
		return get_int_min (expr)

	val_interesting_map = logic.dict_list ([((kind, eval_at (expr, n, vc)), n)
		for (n, vc, kind, expr) in l_arc_interesting])

	smt = stuff['smt']

	for (kind, n, expr, offs, _) in stuff['r_seq_vs']:
		if kind not in l_kinds:
			continue
		if expr.typ.kind != 'Word':
			continue
		expr_n = get_int_min (smt (expr, n, 0))
		offs_n = get_int_min (smt (offs, n, 0))
		# which step counts land on an interesting lhs value?
		hit = ([i for i in range (64)
			if (kind, canon_n (expr_n + (offs_n * i), expr.typ))
				in val_interesting_map])
		if [i for i in hit if i > 4]:
			return ('LoopUnroll', p.loop_id (n), max (hit))
	return None
1005
# debug aid: recent lists of failed pairings recorded by trace_search_fail
last_failed_pairings = []
1007
def setup_split_search (rep, head, restrs, hyps,
		i_opts, j_opts, unfold_limit = None, tags = None,
		node_restrs = None):
	"""Prepare a SearchKnowledge object for a split search at loop
	head 'head', with candidate (start, step) options i_opts for the
	lhs loop and j_opts for rhs loops. Unfolds the relevant loops in
	the solver representation far enough for these options to be
	tested."""
	p = rep.p

	if not tags:
		tags = p.pairing.tags
	if node_restrs == None:
		node_restrs = set (p.nodes)
	if unfold_limit == None:
		# enough visits to observe two steps past any start point
		unfold_limit = max ([start + (2 * step) + 1
			for (start, step) in i_opts + j_opts])

	trace ('Split search at %d, unfold limit %d.' % (head, unfold_limit))

	l_tag, r_tag = tags
	# candidate lhs split points, paired with the (start, step) options
	loop_elts = [(n, start, step) for n in p.splittable_points (head)
		if n in node_restrs
		for (start, step) in i_opts]
	init_to_split = init_loops_to_split (p, restrs)
	r_to_split = [n for n in init_to_split if p.node_tags[n][0] == r_tag]
	# candidate rhs split points, drawn from all unsplit rhs loops
	cand_r_loop_elts = [(n2, start, step) for n in r_to_split
		for n2 in p.splittable_points (n)
		if n2 in node_restrs
		for (start, step) in j_opts]

	err_restrs = restr_others (p, tuple ([(sp, vc_upto (unfold_limit))
		for sp in r_to_split]) + restrs, 1)
	nrerr_pc = mk_not (rep.get_pc (('Err', err_restrs), tag = r_tag))

	def get_pc (n, k):
		restrs2 = get_nth_visit_restrs (rep, restrs, hyps, n, k)
		return rep.get_pc ((n, restrs2))

	# force the representation to unfold the loops up to the limit
	for n in r_to_split:
		get_pc (n, unfold_limit)
	get_pc (head, unfold_limit)

	# premise: no rhs error occurs and all the given hyps hold
	premise = foldr1 (mk_and, [nrerr_pc] + map (rep.interpret_hyp, hyps))
	premise = logic.weaken_assert (premise)

	knowledge = SearchKnowledge (rep,
		'search at %d (unfold limit %d)' % (head, unfold_limit),
		restrs, hyps, tags, (loop_elts, cand_r_loop_elts))
	knowledge.premise = premise
	last_knowledge[0] = knowledge

	# make sure the representation is in sync
	rep.test_hyp_whyps (true_term, hyps)

	# make sure all mem eqs are being tracked
	mem_vs = [v for v in knowledge.v_ids if v[0].typ == builtinTs['Mem']]
	for (i, v) in enumerate (mem_vs):
		for v2 in mem_vs[:i]:
			for pred in expand_var_eqs (knowledge, (v, v2)):
				smt_expr (pred, {}, rep.solv)
	for v in knowledge.v_ids:
		for pred in expand_var_eqs (knowledge, (v, 'Const')):
			smt_expr (pred, {}, rep.solv)

	return knowledge
1069
def get_loop_entry_sites (rep, restrs, hyps, head):
	"""Find the nodes outside the loop at head from which the loop
	might actually be entered (i.e. whose entry path condition is not
	provably false). Results are cached on the problem."""
	p = rep.p
	key = ('loop_entry_sites', restrs, tuple (hyps), p.loop_id (head))
	if key in p.cached_analysis:
		return p.cached_analysis[key]
	entry_preds = set ()
	for body_n in p.loop_body (head):
		for pred_n in p.preds[body_n]:
			if p.loop_id (pred_n) == None:
				entry_preds.add (pred_n)
	def no_entry_hyp (n):
		restrs2 = tuple ([(n2, restr) for (n2, restr) in restrs
			if n2 != n])
		return rep_graph.pc_false_hyp (((n, restrs2),
			p.node_tags[n][0]))
	res = [n for n in entry_preds
		if not rep.test_hyp_imp (hyps, no_entry_hyp (n))]
	p.cached_analysis[key] = res
	return res
1084
def rebuild_knowledge (head, knowledge):
	"""Construct a fresh SearchKnowledge at head covering the same
	(start, step) options, carrying over accumulated facts and
	models."""
	i_opts = sorted (set ([i_det[1:]
		for (i_det, _) in knowledge.pairs]))
	j_opts = sorted (set ([j_det[1:]
		for (_, j_det) in knowledge.pairs]))
	knowledge2 = setup_split_search (knowledge.rep, head,
		knowledge.restrs, knowledge.hyps, i_opts, j_opts)
	knowledge2.facts.update (knowledge.facts)
	for model in knowledge.model_trace:
		knowledge2.add_model (model)
	return knowledge2
1096
def split_search (head, knowledge):
	"""Run the main split search loop: repeatedly propose pairings of
	lhs/rhs loop candidates with variable equalities, test them, and
	return ('Split', split), a ('SingleRevInduct', ..) result, or
	(None, inductive-failures) once all candidates have failed."""
	rep = knowledge.rep
	p = rep.p

	# test any relevant cached solutions.
	p.cached_analysis.setdefault (('v_eqs', head), set ())
	v_eq_cache = p.cached_analysis[('v_eqs', head)]
	for (pair, eqs) in v_eq_cache:
		if pair in knowledge.pairs:
			knowledge.eqs_add_model (list (eqs),
				assert_progress = False)

	while True:
		trace ('In %s' % knowledge.name)
		trace ('Computing live pairings')
		pair_eqs = [(pair, mk_pairing_v_eqs (knowledge, pair))
			for pair in sorted (knowledge.pairs)
			if knowledge.pairs[pair][0] != 'Failed']
		if not pair_eqs:
			# candidates exhausted: report failures and stop
			ind_fails = trace_search_fail (knowledge)
			return (None, ind_fails)

		# 'endorsed' pairings have a consistent var equality set
		endorsed = [(pair, eqs) for (pair, eqs) in pair_eqs
			if eqs != None]
		trace (' ... %d live pairings, %d endorsed' %
			(len (pair_eqs), len (endorsed)))
		knowledge.live_pairs_trace.append (len (pair_eqs))
		for (pair, eqs) in endorsed:
			if knowledge.is_weak_split (eqs):
				trace ('  dropping endorsed - probably weak.')
				knowledge.pairs[pair] = ('Failed',
					'ExpectedSplitWeak', eqs)
				continue
			split = build_and_check_split (p, pair, eqs,
				knowledge.restrs, knowledge.hyps,
				knowledge.tags)
			if split == None:
				knowledge.pairs[pair] = ('Failed',
					'SplitWeak', eqs)
				knowledge.add_weak_split (eqs)
				continue
			elif split == 'InductFailed':
				knowledge.pairs[pair] = ('Failed',
					'InductFailed', eqs)
			elif split[0] == 'SingleRevInduct':
				return split
			else:
				# success: cache the solution for reuse
				v_eq_cache.add ((pair, tuple (eqs)))
				trace ('Found split!')
				return ('Split', split)
		if endorsed:
			# endorsed pairings all failed; recompute the set
			continue

		# no endorsed pairings: test a guess to gather another
		# model and refine the equality candidates
		(pair, _) = pair_eqs[0]
		trace ('Testing guess for pair: %s' % str (pair))
		eqs = mk_pairing_v_eqs (knowledge, pair, endorsed = False)
		assert eqs, pair
		knowledge.eqs_add_model (eqs)
1155
def build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags):
	"""Build a split from the given variable equalities and test it
	inductively. Returns the split, 'InductFailed', or None if no
	split could be built."""
	split = v_eqs_to_split (p, pair, eqs, restrs, hyps, tags = tags)
	if split == None:
		return None
	if check_split_induct (p, restrs, hyps, split, tags = tags):
		return split
	return 'InductFailed'
1165
def build_and_check_split (p, pair, eqs, restrs, hyps, tags):
	"""As build_and_check_split_inner, but on inductive failure tries
	two rescue strategies: extra linear-series equalities for the lhs
	loop, and a speculated inequality about the rhs loop."""
	res = build_and_check_split_inner (p, pair, eqs, restrs, hyps, tags)
	if res != 'InductFailed':
		return res

	# rescue 1: extra linear series equalities for the lhs loop
	((l_split, _, l_step), (r_split, _, _)) = pair
	if get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step):
		# the additional linear eqs get built into the result
		res = build_and_check_split (p, pair, eqs, restrs, hyps,
			tags)
		if res != 'InductFailed':
			return res

	# rescue 2: speculate an inequality about the rhs loop
	spec = get_new_rhs_speculate_ineq (p, restrs, hyps,
		p.loop_id (r_split))
	if spec:
		hyp = check.single_induct_resulting_hyp (p, restrs, spec)
		res = build_and_check_split (p, pair, eqs, restrs,
			hyps + [hyp], tags)
		if res != 'InductFailed':
			return ('SingleRevInduct', spec)
	return 'InductFailed'
1191
def get_new_extra_linear_seq_eqs (p, restrs, l_split, l_step):
	"""Compute (and cache) extra linear-series equalities for the
	loop at l_split. Returns [] when already cached or when the loop
	has no linear-series variables."""
	key = ('extra_linear_seq_eqs', l_split, l_step)
	if key in p.cached_analysis:
		return []
	linear_vs = [v for (v, data)
		in get_loop_var_analysis_at (p, l_split)
		if data[0] == 'LoopLinearSeries']
	if not linear_vs:
		return []
	import loop_bounds
	eqs = loop_bounds.get_linear_series_eqs (p, l_split, restrs, [],
		omit_standard = True)
	p.cached_analysis[key] = eqs
	return eqs
1204
def trace_search_fail (knowledge):
	"""Report the failed pairings once a search is exhausted, and
	return the subset that failed the induction check."""
	trace ('Exhausted split candidates for %s' % knowledge.name)
	fails = [item for item in knowledge.pairs.items ()
		if item[1][0] == 'Failed']
	last_failed_pairings.append (fails)
	del last_failed_pairings[:-10]
	sample = fails[:10]
	trace ('  %d of %d failed pairings:' % (len (sample),
		len (fails)))
	for fail in sample:
		trace ('    %s' % (fail,))
	ind_fails = [item for item in fails
		if str (item[1][1]) == 'InductFailed']
	if ind_fails:
		trace ('Inductive failures!')
	else:
		trace ('No inductive failures.')
	for fail in ind_fails:
		trace ('    %s' % (fail,))
	return ind_fails
1225
def find_split (rep, head, restrs, hyps, i_opts, j_opts,
		unfold_limit = None, tags = None,
		node_restrs = None):
	"""Search for a split of the loop at head. If the plain search
	fails, retry assuming the most common pre-loop path seen in the
	gathered models, and derive a case split if that succeeds."""
	knowledge = setup_split_search (rep, head, restrs, hyps,
		i_opts, j_opts, unfold_limit = unfold_limit,
		tags = tags, node_restrs = node_restrs)

	res = split_search (head, knowledge)

	if res[0]:
		return res

	# plain search failed: look for a common pre-loop path to
	# assume on a second attempt
	(models, facts, n_vcs) = most_common_path (head, knowledge)
	if not n_vcs:
		return res

	[tag, _] = knowledge.tags
	knowledge = setup_split_search (rep, head, restrs,
		hyps + [rep_graph.pc_true_hyp ((n_vc, tag)) for n_vc in n_vcs],
		i_opts, j_opts, unfold_limit, tags, node_restrs = node_restrs)
	knowledge.facts.update (facts)
	for m in models:
		knowledge.add_model (m)
	res = split_search (head, knowledge)

	if res[0] == None:
		return res
	(_, split) = res
	# the split works on the assumed path: justify a case split
	checks = check.split_init_step_checks (rep.p, restrs,
                        hyps, split)

	return derive_case_split (rep, n_vcs, checks)
1258
def most_common_path (head, knowledge):
	"""Find the most common pre-loop path among the models gathered
	so far. Returns (models on that path, facts, n_vcs which
	distinguish it), or (None, None, None) if the models do not
	disagree on the path."""
	rep = knowledge.rep
	[tag, _] = knowledge.tags
	# group models by the loop-free path they take to head
	data = logic.dict_list ([(tuple (entry_path_no_loops (rep,
			tag, m, head)), m)
		for m in knowledge.model_trace])
	if len (data) < 2:
		return (None, None, None)

	(_, path) = max ([(len (data[path]), path) for path in data])
	models = data[path]
	facts = knowledge.facts
	# points visited by every other path too do not distinguish it
	other_n_vcs = set.intersection (* [set (path2) for path2 in data
		if path2 != path])

	n_vcs = []
	pcs = set ()
	for n_vc in path:
		if n_vc in other_n_vcs:
			continue
		if rep.p.loop_id (n_vc[0]):
			continue
		# keep only one point per distinct path condition
		pc = rep.get_pc (n_vc)
		if pc not in pcs:
			pcs.add (pc)
			n_vcs.append (n_vc)
	assert n_vcs

	return (models, facts, n_vcs)
1288
def eval_pc (rep, m, n_vc, tag = None):
	"""Evaluate the path condition of n_vc in model m as a bool."""
	pc = rep.get_pc (n_vc, tag = tag)
	val = eval_model_expr (m, rep.solv, pc)
	assert val in [syntax.true_term, syntax.false_term], (n_vc, val)
	return val == syntax.true_term
1293
def entry_path (rep, tag, m, head):
	"""List the (n, vc) points on the given tag's side which model m
	visits, stopping at head."""
	visited = []
	for (tag2, n, vc) in rep.node_pc_env_order:
		if n == head:
			break
		if tag2 == tag and eval_pc (rep, m, (n, vc), tag):
			visited.append ((n, vc))
	return visited
1304
def entry_path_no_loops (rep, tag, m, head = None):
	"""As entry_path, but restricted to points outside any loop."""
	path = entry_path (rep, tag, m, head)
	return [n_vc for n_vc in path if not rep.p.loop_id (n_vc[0])]
1309
# debug aid: arguments of the most recent derive_case_split call
last_derive_case_split = [0]
1311
def derive_case_split (rep, n_vcs, checks):
	"""Given path points n_vcs whose assumption makes the checks
	pass, binary-search for a single point that suffices, returning
	the ('CaseSplit', ..) proof hint for it."""
	last_derive_case_split[0] = (rep.p, n_vcs, checks)
	# remove duplicate pcs
	n_vcs_uniq = dict ([(rep.get_pc (n_vc), (i, n_vc))
		for (i, n_vc) in enumerate (n_vcs)]).values ()
	n_vcs = [n_vc for (i, n_vc) in sorted (n_vcs_uniq)]
	assert n_vcs
	tag = rep.p.node_tags[n_vcs[0][0]][0]
	keep_n_vcs = []
	test_n_vcs = n_vcs
	mk_thyps = lambda n_vcs: [rep_graph.pc_true_hyp ((n_vc, tag))
		for n_vc in n_vcs]
	while len (test_n_vcs) > 1:
		# floor division (//): identical to / on ints in python 2
		# and avoids a float index if run under python 3
		i = len (test_n_vcs) // 2
		test_in = test_n_vcs[:i]
		test_out = test_n_vcs[i:]
		checks2 = [(hyps + mk_thyps (test_in + keep_n_vcs), hyp, nm)
			for (hyps, hyp, nm) in checks]
		(verdict, _) = check.test_hyp_group (rep, checks2)
		if verdict:
			# forget n_vcs that were tested out
			test_n_vcs = test_in
		else:
			# focus on n_vcs that were tested out
			test_n_vcs = test_out
			keep_n_vcs.extend (test_in)
	[(n, vc)] = test_n_vcs
	return ('CaseSplit', ((n, tag), [n]))
1340
def mk_seq_eqs (p, split, step, with_rodata):
	"""Build the sequence of expressions expected constant along
	visits to the loop at split with the given step size.

	The variable '%i' stands for the position in the sequence;
	linear-series forms subtract (step * %i) to normalise."""
	pos = mk_var ('%i', word32T)
	if step == 1:
		neg_step = mk_uminus (pos)
	else:
		neg_step = mk_times (pos, mk_word32 (- step))

	eqs = []
	for (var, data) in get_loop_var_analysis_at (p, split):
		if data == 'LoopVariable':
			if with_rodata and var.typ == builtinTs['Mem']:
				eqs.append (logic.mk_rodata (var))
		elif data == 'LoopConst':
			if var.typ not in syntax.phantom_types:
				eqs.append (var)
		elif data == 'LoopLeaf':
			# leaf variables contribute nothing here
			pass
		elif data[0] == 'LoopLinearSeries':
			(_, form, _) = data
			eqs.append (form (var,
				mk_cast (neg_step, var.typ)))
		else:
			assert not 'var_deps type understood'

	# include any extra linear-series eqs discovered previously
	eqs.extend (p.cached_analysis.get (('extra_linear_seq_eqs',
		split, step), []))

	return eqs
1373
def c_memory_loop_invariant (p, c_sp, a_sp):
	"""Pick the C memory variables to hold inductively equal to their
	initial values when the ASM side keeps memory constant."""
	def loop_mem_vars (sp):
		return [v for (v, data) in get_loop_var_analysis_at (p, sp)
			if v.typ == builtinTs['Mem']
			if data == 'LoopVariable']

	if loop_mem_vars (a_sp):
		return []
	# if ASM keeps memory constant through the loop, it is implying this
	# is semantically possible in C also, though it may not be
	# syntactically the case
	# anyway, we have to assert C memory equals *something* inductively
	# so we pick C initial memory.
	return loop_mem_vars (c_sp)
1388
def v_eqs_to_split (p, pair, v_eqs, restrs, hyps, tags = None):
	"""Package the variable equalities for the candidate pair into a
	full split description, checking that the rhs visit count can in
	fact be bounded accordingly. Returns the split or None."""
	trace ('v_eqs_to_split: (%s, %s)' % pair)

	((l_n, l_init, l_step), (r_n, r_init, r_step)) = pair
	# sequence eqs per side, plus lhs vars asserted constant
	l_details = (l_n, (l_init, l_step), mk_seq_eqs (p, l_n, l_step, True)
		+ [v_i[0] for (v_i, v_j) in v_eqs if v_j == 'Const'])
	r_details = (r_n, (r_init, r_step), mk_seq_eqs (p, r_n, r_step, False)
		+ c_memory_loop_invariant (p, r_n, l_n))

	# cross-side equalities (casting rhs values to lhs types)
	eqs = [(v_i[0], mk_cast (v_j[0], v_i[0].typ))
		for (v_i, v_j) in v_eqs if v_j != 'Const'
		if v_i[0] != syntax.mk_word32 (0)]

	n = 2
	split = (l_details, r_details, eqs, n, (n * r_step) - 1)
	trace ('Split: %s' % (split, ))
	if tags == None:
		tags = p.pairing.tags
	hyps = hyps + check.split_loop_hyps (tags, split, restrs, exit = True)

	# the rhs loop must be bounded at exactly n * r_step visits
	r_max = get_split_limit (p, r_n, restrs, hyps, 'Offset',
		bound = (n + 2) * r_step, must_find = False,
		hints = [n * r_step, n * r_step + 1])
	if r_max == None:
		trace ('v_eqs_to_split: no RHS limit')
		return None

	if r_max > n * r_step:
		trace ('v_eqs_to_split: RHS limit not %d' % (n * r_step))
		return None
	trace ('v_eqs_to_split: split %s' % (split,))
	return split
1421
def get_n_offset_successes (rep, sp, step, restrs):
	"""For each Call node in the loop at sp, collect the implication
	'pc ==> call success' at each of the step visit offsets."""
	body = rep.p.loop_body (sp)
	call_ns = [n for n in body if rep.p.nodes[n].kind == 'Call']
	succs = []
	for i in range (step):
		for n in call_ns:
			# the head itself is seen one offset earlier
			if n == sp:
				vc = vc_offs (i)
			else:
				vc = vc_offs (i + 1)
			n_vc = (n, restrs + tuple ([(sp, vc)]))
			(_, _, succ) = rep.get_func (n_vc)
			pc = rep.get_pc (n_vc)
			succs.append (syntax.mk_implies (pc, succ))
	return succs
1436
# operator names treated as binary (in)equality comparisons by
# split_linear_eq below
eq_ineq_ops = set (['Equals', 'Less', 'LessEquals',
	'SignedLess', 'SignedLessEquals'])
1439
def split_linear_eq (cond):
	"""Split a (possibly negated) comparison condition into its two
	sides, expanding PArrayValid into its size inequality. Returns
	(lhs, rhs) or None if cond is not a comparison."""
	if cond.is_op ('Not'):
		[inner] = cond.vals
		return split_linear_eq (inner)
	if cond.is_op (eq_ineq_ops):
		(lhs, rhs) = (cond.vals[0], cond.vals[1])
		return (lhs, rhs)
	if cond.is_op ('PArrayValid'):
		[htd, typ_expr, p, num] = cond.vals
		assert typ_expr.kind == 'Type'
		ineq = logic.mk_array_size_ineq (typ_expr.val, num, p)
		return split_linear_eq (ineq)
	return None
1453
def possibly_linear_ineq (cond):
	"""Test whether cond is a comparison both of whose sides could
	be linear functions of the loop variables."""
	sides = split_linear_eq (cond)
	if not sides:
		return False
	(lhs, rhs) = sides
	return logic.possibly_linear (lhs) and logic.possibly_linear (rhs)
1460
def linear_const_comparison (p, n, cond):
	"""examines a condition. if it is a linear (e.g. Less) comparison
	between a linear series variable and a loop-constant expression,
	return (linear side, const side), or None if not the case."""
	sides = split_linear_eq (cond)
	if not sides:
		return None
	(lhs, rhs) = sides
	offs = logic.get_loop_linear_offs (p, p.loop_id (n))
	lhs_offs = offs (n, lhs)
	rhs_offs = offs (n, rhs)
	zero = mk_num (0, lhs.typ)
	oset = set ([lhs_offs, rhs_offs])
	# exactly one side must step by zero (loop-constant) while the
	# other steps by a known amount
	if zero not in oset or None in oset or len (oset) <= 1:
		return None
	if lhs_offs == zero:
		return (rhs, lhs)
	return (lhs, rhs)
1480
def do_linear_rev_test (rep, restrs, hyps, split, eqs_assume, pred, large):
	"""Run the reverse-induction step and base checks for the given
	speculated predicate; True iff every check group passes."""
	p = rep.p
	(tag, _) = p.node_tags[split]
	checks = (check.single_loop_rev_induct_checks (p, restrs, hyps, tag,
			split, eqs_assume, pred)
		+ check.single_loop_rev_induct_base_checks (p, restrs, hyps,
			tag, split, large, eqs_assume, pred))

	for group in check.proof_check_groups (checks):
		(verdict, _) = check.test_hyp_group (rep, group)
		if not verdict:
			return False
	return True
1495
def get_extra_assn_linear_conds (expr):
	"""Flatten an assertion into a list of conditions, expanding a
	PArrayValid disjunct into its implied size and alignment
	inequalities."""
	if expr.is_op ('And'):
		return [cond for conj in logic.split_conjuncts (expr)
			for cond in get_extra_assn_linear_conds (conj)]
	if not expr.is_op ('Or'):
		return [expr]
	arr_valids = [v for v in expr.vals if v.is_op ('PArrayValid')]
	if not arr_valids:
		return [expr]
	[htd, typ_expr, p, num] = arr_valids[0].vals
	assert typ_expr.kind == 'Type'
	typ = typ_expr.val
	less_eq = logic.mk_array_size_ineq (typ, num, p)
	assn = logic.mk_align_valid_ineq (('Array', typ, num), p)
	return get_extra_assn_linear_conds (assn) + [less_eq]
1511
def get_rhs_speculate_ineq (p, restrs, loop_head):
	"""Cached wrapper around rhs_speculate_ineq, keyed on the loop id
	and the restrictions relevant to its side of the problem."""
	assert p.loop_id (loop_head), loop_head
	loop_head = p.loop_id (loop_head)
	tag = p.node_tags[loop_head][0]
	restrs = tuple ([(n, vc) for (n, vc) in restrs
		if p.node_tags[n][0] == tag])
	key = ('rhs_speculate_ineq', restrs, loop_head)
	if key not in p.cached_analysis:
		p.cached_analysis[key] = rhs_speculate_ineq (p, restrs,
			loop_head)
	return p.cached_analysis[key]
1524
def get_new_rhs_speculate_ineq (p, restrs, hyps, loop_head):
	"""As get_rhs_speculate_ineq, but returns None if an equivalent
	hypothesis is already present in hyps."""
	res = get_rhs_speculate_ineq (p, restrs, loop_head)
	if res == None:
		return None
	(point, _, (pred, _)) = res
	for h in hyps:
		visit_ns = [n for ((n, _), _) in h.visits ()]
		if point in visit_ns and pred in h.get_vals ():
			return None
	return res
1535
def rhs_speculate_ineq (p, restrs, loop_head):
	"""code for handling an interesting case in which the compiler
	knows that the RHS program might fail in the future. for instance,
	consider a loop that cannot be exited until iterator i reaches value n.
	any error condition which implies i < b must hold of i - 1, thus
	n <= b.

	detects this case and identifies the inequality n <= b"""
	body = p.loop_body (loop_head)

	# if the loop contains function calls, skip it,
	# otherwise we need to figure out whether they terminate
	if [n for n in body if p.nodes[n].kind == 'Call']:
		return None

	# the conditional nodes through which the loop can be exited
	exit_nodes = set ([n for n in body for n2 in p.nodes[n].get_conts ()
			if n2 != 'Err' if n2 not in body])
	assert set ([p.nodes[n].kind for n in exit_nodes]) <= set (['Cond'])

	# if there are multiple exit conditions, too hard for now
	if len (exit_nodes) > 1:
		return None

	# the exit condition must compare a linear-series value against
	# a loop constant
	[exit_n] = list (exit_nodes)
	rv = linear_const_comparison (p, exit_n, p.nodes[exit_n].cond)
	if not rv:
		return None
	(linear, const) = rv

	# gather candidate linear inequalities from the negated error
	# conditions within the loop body
	err_cond_sites = [(n, p.nodes[n].err_cond ()) for n in body]
	err_conds = set ([(n, cond) for (n, err_cond) in err_cond_sites
		if err_cond
		for assn in logic.split_conjuncts (mk_not (err_cond))
		for cond in get_extra_assn_linear_conds (assn)
		if possibly_linear_ineq (cond)])
	if not err_conds:
		return None

	assert const.typ.kind == 'Word'
	rep = rep_graph.mk_graph_slice (p)
	eqs = mk_seq_eqs (p, exit_n, 1, False)
	import loop_bounds
	eqs += loop_bounds.get_linear_series_eqs (p, exit_n,
                restrs, [], omit_standard = True)

	# binary-search for the tightest bound on the constant which
	# reverse induction can establish, in either direction
	large = (2 ** const.typ.num) - 3
	const_less = lambda n: mk_less (const, mk_num (n, const.typ))
	less_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n,
		eqs, const_less (n), large)
	const_ge = lambda n: mk_less (mk_num (n, const.typ), const)
	ge_test = lambda n: do_linear_rev_test (rep, restrs, [], exit_n,
		eqs, const_ge (n), large)

	res = logic.binary_search_least (less_test, 1, large)
	if res:
		return (loop_head, (eqs, 1), (const_less (res), large))
	res = logic.binary_search_greatest (ge_test, 0, large)
	if res:
		return (loop_head, (eqs, 1), (const_ge (res), large))
	return None
1596
def check_split_induct (p, restrs, hyps, split, tags = None):
	"""perform both the induction check and a function-call based check
	on successes which can avoid some problematic inductions."""
	((l_split, (_, l_step), _), (r_split, (_, r_step), _), _, n, _) = split
	if tags == None:
		tags = p.pairing.tags

	err_hyp = check.split_r_err_pc_hyp (p, split, restrs, tags = tags)
	hyps = [err_hyp] + hyps + check.split_loop_hyps (tags, split,
		restrs, exit = False)

	rep = mk_graph_slice (p)

	if not check.check_split_induct_step_group (rep, restrs, hyps, split,
			tags = tags):
		return False

	# call-success check: lhs calls must succeed assuming rhs calls do
	l_succs = get_n_offset_successes (rep, l_split, l_step, restrs)
	r_succs = get_n_offset_successes (rep, r_split, r_step, restrs)

	# no lhs calls: nothing further to verify
	if not l_succs:
		return True

	hyp = syntax.foldr1 (syntax.mk_and, l_succs)
	if r_succs:
		hyp = syntax.mk_implies (foldr1 (syntax.mk_and, r_succs), hyp)

	return rep.test_hyp_whyps (hyp, hyps)
1625
def init_loops_to_split (p, restrs):
	"""Pick the loops to split first: those not reachable from any
	other loop still awaiting a split."""
	to_split = loops_to_split (p, restrs)
	def preceded (n):
		return [n2 for n2 in to_split
			if n2 != n and p.is_reachable_from (n2, n)]
	return [n for n in to_split if not preceded (n)]
1632
def restr_others_both (p, restrs, n, m):
	"""Restrict all remaining split points to visit counts within
	both the number range and the offset range (n, m)."""
	extras = tuple ([(sp, vc_double_range (n, m))
		for sp in loops_to_split (p, restrs)])
	return restrs + extras
1637
def restr_others_as_necessary (p, n, restrs, init_bound, offs_bound,
		skip_loops = None):
	"""Restrict the visit counts of the remaining split points, but
	only for loops which can actually reach node n.

	skip_loops optionally lists split points to leave unrestricted."""
	# use a None sentinel rather than a mutable [] default argument
	if skip_loops == None:
		skip_loops = []
	extras = [(sp, vc_double_range (init_bound, offs_bound))
		for sp in loops_to_split (p, restrs)
		if sp not in skip_loops
		if p.is_reachable_from (sp, n)]
	return restrs + tuple (extras)
1645
def loop_no_match_unroll (rep, restrs, hyps, split, other_tag, unroll):
	"""Check whether unrolling the loop at split 'unroll' times shows
	the loop to be weak (unreachable) or independent (the other side
	can return without matching it)."""
	p = rep.p
	assert p.node_tags[split][0] != other_tag
	restr = ((split, vc_num (unroll)), )
	restrs2 = restr_others (p, restr + restrs, 2)
	loop_cond = rep.get_pc ((split, restr + restrs))
	ret_cond = rep.get_pc (('Ret', restrs2), tag = other_tag)
	# loop should be reachable
	if rep.test_hyp_whyps (mk_not (loop_cond), hyps):
		trace ('Loop weak at %d (unroll count %d).' %
			(split, unroll))
		return True
	# reaching the loop should imply reaching a loop on the other side
	hyp = mk_not (mk_and (loop_cond, ret_cond))
	if not rep.test_hyp_whyps (hyp, hyps):
		trace ('Loop independent at %d (unroll count %d).' %
			(split, unroll))
		return True
	return False
1665
def loop_no_match (rep, restrs, hyps, split, other_tag,
		check_speculate_ineq = False):
	"""Decide whether the loop at split fails to match any loop on
	the other side at unroll counts 4 and 8. Returns False, 'Restr',
	or 'SingleRevInduct' if a speculated inequality explains it."""
	for unroll in [4, 8]:
		if not loop_no_match_unroll (rep, restrs, hyps, split,
				other_tag, unroll):
			return False
	if not check_speculate_ineq:
		return 'Restr'
	spec = get_new_rhs_speculate_ineq (rep.p, restrs, hyps, split)
	if not spec:
		return 'Restr'
	hyp = check.single_induct_resulting_hyp (rep.p, restrs, spec)
	if not loop_no_match_unroll (rep, restrs, hyps + [hyp], split,
			other_tag, 8):
		return 'SingleRevInduct'
	return 'Restr'
1682
# debug aid: the most recent results returned by proof searchers
last_searcher_results = []
1684
def restr_point_name (p, n):
	"""Describe point n for restriction announcements, noting whether
	it is a loop head or inside a loop."""
	# bug fix: the first branch previously repeated the bare
	# p.loop_id (n) test, making the 'in loop' branch unreachable.
	# a loop head is its own loop id.
	if p.loop_id (n) == n:
		return '%s (loop head)' % n
	elif p.loop_id (n):
		return '%s (in loop %d)' % (n, p.loop_id (n))
	else:
		return str (n)
1692
def fail_searcher (p, restrs, hyps):
	"""A searcher which always gives up; useful as a stub."""
	result = ('Fail Searcher', None)
	return result
1695
def build_proof_rec (searcher, p, restrs, hyps, name = "problem"):
	"""Recursively construct a ProofNode tree by invoking the
	searcher on the current (restrs, hyps) subproblem and recursing
	into the subproblems its verdict generates."""
	trace ('doing build proof rec with restrs = %r, hyps = %r' % (restrs, hyps))
	if searcher == None:
		searcher = default_searcher

	(kind, details) = searcher (p, restrs, hyps)
	last_searcher_results.append ((p, restrs, hyps, kind, details, name))
	del last_searcher_results[:-10]
	if kind == 'Restr':
		(restr_kind, restr_points) = details
		printout ("Discovered that points [%s] can be bounded"
			% ', '.join ([restr_point_name (p, n)
				for n in restr_points]))
		printout ("  (in %s)" % name)
		restr_hints = [(n, restr_kind, True) for n in restr_points]
		return build_proof_rec_with_restrs (restr_hints,
			searcher, p, restrs, hyps, name = name)
	elif kind == 'Leaf':
		return ProofNode ('Leaf', None, ())
	assert kind in ['CaseSplit', 'Split', 'SingleRevInduct'], kind
	if kind == 'CaseSplit':
		(details, hints) = details
	# split the problem into per-branch subproblems
	probs = check.proof_subproblems (p, kind, details, restrs, hyps, name)
	if kind == 'CaseSplit':
		printout ("Decided to case split at %s" % str (details))
		printout ("  (in %s)" % name)
		restr_hints = [[(n, 'Number', False) for n in hints]
			for cases in [0, 1]]
	elif kind == 'SingleRevInduct':
		printout ('Found a future induction at %s' % str (details[0]))
		restr_hints = [[]]
	else:
		restr_points = check.split_heads (details)
		restr_hints = [[(n, rkind, True) for n in restr_points]
			for rkind in ['Number', 'Offset']]
		printout ("Discovered a loop relation for split points %s"
			% list (restr_points))
		printout ("  (in %s)" % name)
	subpfs = []
	for ((restrs, hyps, name), hints) in logic.azip (probs, restr_hints):
		printout ('Now doing proof search in %s.' % name)
		pf = build_proof_rec_with_restrs (hints, searcher,
			p, restrs, hyps, name = name)
		subpfs.append (pf)
	return ProofNode (kind, details, subpfs)
1741
def build_proof_rec_with_restrs (split_hints, searcher, p, restrs,
		hyps, name = "problem"):
	"""Apply the restriction hints (point, kind, must_find) one at a
	time, wrapping the recursively-built proof in Restr nodes."""
	if not split_hints:
		return build_proof_rec (searcher, p, restrs, hyps, name = name)

	(sp, kind, must_find) = split_hints[0]
	use_hyps = list (hyps)
	if p.node_tags[sp][0] != p.pairing.tags[1]:
		# when bounding non-rhs points, assume no rhs error
		nrerr_hyp = check.non_r_err_pc_hyp (p.pairing.tags,
			restr_others (p, restrs, 2))
		use_hyps = use_hyps + [nrerr_hyp]

	if p.loop_id (sp):
		lim_pair = get_proof_split_limit (p, sp, restrs, use_hyps,
			kind, must_find = must_find)
	else:
		lim_pair = get_proof_visit_restr (p, sp, restrs, use_hyps,
			kind, must_find = must_find)

	if not lim_pair:
		assert not must_find
		# no bound found for an optional hint: skip it
		return build_proof_rec_with_restrs (split_hints[1:],
			searcher, p, restrs, hyps, name = name)

	(min_v, max_v) = lim_pair
	if kind == 'Number':
		vc_opts = rep_graph.vc_options (range (min_v, max_v), [])
	else:
		vc_opts = rep_graph.vc_options ([], range (min_v, max_v))

	restrs = restrs + ((sp, vc_opts), )
	subproof = build_proof_rec_with_restrs (split_hints[1:],
		searcher, p, restrs, hyps, name = name)

	return ProofNode ('Restr', (sp, (kind, (min_v, max_v))), [subproof])
1777
def get_proof_split_limit (p, sp, restrs, hyps, kind, must_find = False):
	"""Find a visit-count bound for loop point sp, double-checking it
	with a non-fast graph slice. Returns (0, bound) or None."""
	limit = get_split_limit (p, sp, restrs, hyps, kind,
		must_find = must_find)
	if limit == None:
		return None
	# double-check this limit with a rep constructed without the 'fast' flag
	slow_rep = mk_graph_slice (p)
	limit = find_split_limit (p, sp, restrs, hyps, kind,
		hints = [limit, limit + 1], use_rep = slow_rep)
	return (0, limit + 1)
1787
def get_proof_visit_restr (p, sp, restrs, hyps, kind, must_find = False):
	"""Decide whether non-loop point sp is always or never visited
	under the hyps, returning the (min, max) visit-count window or
	None if neither can be shown."""
	rep = rep_graph.mk_graph_slice (p)
	pc = rep.get_pc ((sp, restrs))
	if rep.test_hyp_whyps (pc, hyps):
		# provably reached
		return (1, 2)
	if rep.test_hyp_whyps (mk_not (pc), hyps):
		# provably not reached
		return (0, 1)
	assert not must_find
	return None
1798
def default_searcher (p, restrs, hyps):
	"""The standard proof searcher: handles init case splits, then
	unmatched rhs loops, then attempts loop splits for lhs loops,
	falling back to simple visit-count restrictions."""
	# use any handy init splits
	res = init_proof_case_split (p, restrs, hyps)
	if res:
		return res

	# detect any un-split loops
	to_split_init = init_loops_to_split (p, restrs)
	rep = mk_graph_slice (p, fast = True)

	l_tag, r_tag = p.pairing.tags
	l_to_split = [n for n in to_split_init if p.node_tags[n][0] == l_tag]
	r_to_split = [n for n in to_split_init if p.node_tags[n][0] == r_tag]
	l_ep = p.get_entry (l_tag)
	r_ep = p.get_entry (r_tag)

	# rhs loops with no lhs counterpart must simply be bounded
	for r_sp in r_to_split:
		trace ('checking loop_no_match at %d' % r_sp, push = 1)
		res = loop_no_match (rep, restrs, hyps, r_sp, l_tag,
			check_speculate_ineq = True)
		if res == 'Restr':
			return ('Restr', ('Number', [r_sp]))
		elif res == 'SingleRevInduct':
			spec = get_new_rhs_speculate_ineq (p, restrs, hyps, r_sp)
			assert spec
			return ('SingleRevInduct', spec)
		trace (' .. done checking loop no match', push = -1)

	if l_to_split and not r_to_split:
		n = l_to_split[0]
		trace ('lhs loop alone, limit must be found.')
		return ('Restr', ('Number', [n]))

	if l_to_split:
		n = l_to_split[0]
		trace ('checking lhs loop_no_match at %d' % n, push = 1)
		if loop_no_match (rep, restrs, hyps, n, r_tag):
			trace ('loop does not match!', push = -1)
			return ('Restr', ('Number', [n]))
		trace (' .. done checking loop no match', push = -1)

		# attempt a full loop split for this lhs loop
		(kind, split) = find_split_loop (p, n, restrs, hyps)
		if kind == 'LoopUnroll':
			return ('Restr', ('Number', [split]))
		return (kind, split)

	if r_to_split:
		n = r_to_split[0]
		trace ('rhs loop alone, limit must be found.')
		return ('Restr', ('Number', [n]))

	# nothing left to do on this branch
	return ('Leaf', None)
1851
def use_split_searcher (p, split):
	"""Make a searcher which proposes the given split whenever all of
	its loop heads are still awaiting a split."""
	wanted = set ([p.loop_id (h) for h in check.split_heads (split)])
	def searcher (p, restrs, hyps):
		awaiting = set ([p.loop_id (h)
			for h in init_loops_to_split (p, restrs)])
		if wanted <= awaiting:
			return ('Split', split)
		return default_searcher (p, restrs, hyps)
	return searcher
1862
1863