1# Copyright 2016-2017 Tobias Grosser
2#
3# Use of this software is governed by the MIT license
4#
5# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
6
7import sys
8import isl
9
# Check that isl objects can be created through every supported route:
#  - parsing a string
#  - converting a Python integer
#  - calling a static constructor that takes no arguments
#  - converting one isl type into another (basic_set to set)
#  - building an empty union set
#
# Construction from integers and strings overlaps with the parameter type
# tests; the point here is that several overloaded constructors exist and
# overload resolution has to pick the right one.
#
def test_constructors():
	# Three different ways of obtaining the value zero.
	assert isl.val("0").is_zero()
	assert isl.val(0).is_zero()
	assert isl.val.zero().is_zero()

	# A basic_set converted into a set must equal the same set parsed directly.
	bset = isl.basic_set("{ [1] }")
	expected = isl.set("{ [1] }")
	assert isl.set(bset).is_equal(expected)

	# Taking the union with the empty union set must change nothing.
	uset = isl.union_set("{ A[1]; B[2, 3] }")
	nothing = isl.union_set.empty()
	assert uset.is_equal(uset.union(nothing))
41
# Check that the integer parameter "i" reaches isl intact: the value built
# from the Python integer must equal the one parsed from its decimal string.
#
def test_int(i):
	parsed_from_int = isl.val(i)
	parsed_from_str = isl.val(str(i))
	assert parsed_from_int.eq(parsed_from_str)
48
# Check integer parameters at the largest and smallest Py_ssize_t values
# (sys.maxsize and -sys.maxsize - 1) and at zero.
#
def test_parameters_int():
	largest = sys.maxsize
	for value in (largest, -largest - 1, 0):
		test_int(value)
57
# Check isl object parameters.
#
# An isl object must be accepted both when bound to a variable (lvalue) and
# when it is the unnamed result of another call (rvalue).  An argument whose
# type is related by inheritance (a basic_set where a set is expected) must
# be converted automatically.  Finally, a method that takes no argument
# besides the object it is invoked on must work as well.
#
def test_parameters_obj():
	first = isl.set("{ [0] }")
	second = isl.set("{ [1] }")
	third = isl.set("{ [2] }")
	expected = isl.set("{ [i] : 0 <= i <= 2 }")

	# Intermediate result stored in a named variable.
	partial = first.union(second)
	assert partial.union(third).is_equal(expected)

	# Intermediate result passed on without ever being named.
	assert first.union(second).union(third).is_equal(expected)

	# The basic_set argument is converted to a set implicitly.
	first_basic = isl.basic_set("{ [0] }")
	assert first.is_equal(first_basic)

	# val.inv takes no argument other than the object itself.
	two = isl.val(2)
	half = isl.val("1/2")
	assert two.inv().eq(half)
86
# Run the tests for the different kinds of function parameters:
# integers and isl objects.
#
def test_parameters():
	for check in (test_parameters_int, test_parameters_obj):
		check()
94
# Check that isl objects are returned correctly.
#
# This only verifies that the object produced by combining two objects is
# handed back to the caller successfully.
#
def test_return_obj():
	one, two, three = (isl.val(text) for text in ("1", "2", "3"))
	total = one.add(two)
	assert total.eq(three)
108
# Check that integer return values arrive correctly, using the sign of a
# positive, a negative and a zero value.
#
def test_return_int():
	positive = isl.val("1")
	negative = isl.val("-1")
	zero = isl.val("0")

	assert positive.sgn() > 0
	assert negative.sgn() < 0
	assert zero.sgn() == 0
119
# Check that isl_bool return values arrive correctly.
#
# Both a true and a false result must convert to the matching Python
# truth value.
#
def test_return_bool():
	empty_set = isl.set("{ : false }")
	universe = isl.set("{ : }")

	says_empty = empty_set.is_empty()
	says_not_empty = universe.is_empty()

	assert says_empty
	assert not says_not_empty
133
# Check that strings are returned correctly.
# The strings are the C renderings of AST expressions produced by the
# overloaded isl.ast_build.expr_from method, called once with a pw_aff and
# once with a set.
#
def test_return_string():
	context = isl.set("[n] -> { : }")
	build = isl.ast_build.from_context(context)
	pa = isl.pw_aff("[n] -> { [n] }")
	nonneg = isl.set("[n] -> { : n >= 0 }")

	for source, expected in ((pa, "n"), (nonneg, "n >= 0")):
		rendered = build.expr_from(source).to_C_str()
		assert expected == rendered
150
# Run the return value tests: isl objects, integers, boolean values and
# strings must all be handed back to the caller correctly.
#
def test_return():
	checks = (
		test_return_obj,
		test_return_int,
		test_return_bool,
		test_return_string,
	)
	for check in checks:
		check()
161
# Helper class whose instances test_user attaches to an isl.id, so that it
# can check that the attached object is returned by isl.id.user.
#
class S:
	def __init__(self):
		# test_user checks for exactly this value.
		self.value = 42
167
# Check isl.id.user.
#
# Whatever was attached to an identifier (an integer, nothing at all, or an
# instance of a Python class) must be retrievable again through user().
#
def test_user():
	with_int = isl.id("test", 5)
	without_user = isl.id("test2")
	with_obj = isl.id("S", S())

	assert with_int.user() == 5, \
		f"unexpected user object {with_int.user()}"
	assert without_user.user() is None, \
		f"unexpected user object {without_user.user()}"

	obj = with_obj.user()
	assert isinstance(obj, S), f"unexpected user object {obj}"
	assert obj.value == 42, f"unexpected user object {obj}"
182
# Test that foreach functions are modeled correctly.
#
# Verify that a closure passed as the callback of a 'foreach' function is
# called for each basic set and can update variables it captures.  Also
# check that an exception raised from the closure is propagated out of
# the foreach function.
#
def test_foreach():
	s = isl.set("{ [0]; [1]; [2] }")

	collected = []
	def add(bs):
		collected.append(bs)
	s.foreach_basic_set(add)

	assert(len(collected) == 3)
	for bs in collected:
		assert(bs.is_subset(s))
	assert(not collected[0].is_equal(collected[1]))
	assert(not collected[0].is_equal(collected[2]))
	assert(not collected[1].is_equal(collected[2]))

	# Count the calls, so that an exception raised before the callback is
	# ever reached (e.g., an AttributeError from a misspelled method name)
	# cannot make this check pass.
	fail_calls = [0]
	def fail(bs):
		fail_calls[0] += 1
		raise Exception("fail")

	caught = False
	try:
		s.foreach_basic_set(fail)
	except Exception:
		caught = True
	assert(caught)
	assert(fail_calls[0] > 0)
215
# Check the "foreach_scc" functions.
#
# The elements used here can be ordered completely, except that two of them
# ("a" and "b") are incomparable with each other.
#
def test_foreach_scc():
	ids = isl.id_list(3)
	result = isl.id_list(3)
	maps = {
		'a' : isl.map("{ [0] -> [1] }"),
		'b' : isl.map("{ [1] -> [0] }"),
		'c' : isl.map("{ [i = 0:1] -> [i] }"),
	}
	for name in maps:
		ids = ids.add(name)
	identity = maps['a'].space().domain().identity_multi_pw_aff_on_domain()

	# Ordering predicate handed to foreach_scc; presumably it reports
	# whether "a" has to come after "b" -- see the isl manual.
	def follows(a, b):
		composed = maps[b.name()].apply_domain(maps[a.name()])
		return not composed.lex_ge_at(identity).is_empty()

	# Every component must be a singleton; append it to the result.
	def add_single(scc):
		nonlocal result
		assert scc.size() == 1
		result = result.concat(scc)

	ids.foreach_scc(follows, add_single)
	assert result.size() == 3
	assert [result.at(k).name() for k in range(3)] == ["b", "c", "a"]
245
# Test the functionality of "every" functions.
#
# In particular, test the generic functionality (the result holds only if
# the callback holds for every set) and test that exceptions raised from
# the callback are properly propagated.
#
def test_every():
	us = isl.union_set("{ A[i]; B[j] }")

	def is_empty(s):
		return s.is_empty()
	assert(not us.every_set(is_empty))

	def is_non_empty(s):
		return not s.is_empty()
	assert(us.every_set(is_non_empty))

	def in_A(s):
		return s.is_subset(isl.set("{ A[x] }"))
	assert(not us.every_set(in_A))

	def not_in_A(s):
		return not s.is_subset(isl.set("{ A[x] }"))
	assert(not us.every_set(not_in_A))

	# Count the calls, so that an exception raised before the callback is
	# reached (e.g., an AttributeError from a misspelled method name)
	# cannot make the propagation check pass vacuously.
	fail_calls = [0]
	def fail(s):
		fail_calls[0] += 1
		raise Exception("fail")

	caught = False
	try:
		us.every_set(fail)
	except Exception:
		caught = True
	assert(caught)
	assert(fail_calls[0] > 0)
279
# Check that spaces can be built from the unit space by adding named tuples,
# and that the universes of those spaces are the expected set and map.
#
def test_space():
	set_space = isl.space.unit().add_named_tuple("A", 3)
	map_space = set_space.add_named_tuple("B", 2)

	universe_set = isl.set.universe(set_space)
	universe_map = isl.map.universe(map_space)
	assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
	assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))
291
# Build a small schedule tree: a sequence node near the top, with a
# single-dimensional band node in each of its two branches.  Only the band
# in the branch for statement A is marked coincident.
#
def construct_schedule_tree():
	A = isl.union_set("{ A[i] : 0 <= i < 10 }")
	B = isl.union_set("{ B[i] : 0 <= i < 20 }")

	root = isl.schedule_node.from_domain(A.union(B))
	seq = root.child(0).insert_sequence(isl.union_set_list(A).add(B))

	# Descend into branch "pos" of the sequence, insert a band for
	# "partial" below that branch's filter, optionally mark its first
	# member coincident, and climb back up to the sequence node.
	def add_band(node, pos, partial, coincident):
		band = node.child(pos).child(0).insert_partial_schedule(partial)
		if coincident:
			band = band.member_set_coincident(0, True)
		return band.ancestor(2)

	f_A = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
	seq = add_band(seq, 0, f_A, True)

	f_B = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
	seq = add_band(seq, 1, f_B, False)

	return seq.schedule()
320
# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
# - check that the root node is a domain node
# - test map_descendant_bottom_up, including exception propagation
# - test foreach_descendant_top_down, including stopping the descent
# - test every_descendant, including exception propagation
# - check that the filters in the tree together cover the domain
#
def test_schedule_tree():
	schedule = construct_schedule_tree()
	root = schedule.root()

	assert(isinstance(root, isl.schedule_node_domain))

	count = [0]
	def inc_count_map(node):
		count[0] += 1
		return node
	root = root.map_descendant_bottom_up(inc_count_map)
	assert(count[0] == 8)

	# Count the calls, so that an exception raised before the callback is
	# reached cannot satisfy the propagation checks below.
	fail_map_calls = [0]
	def fail_map(node):
		fail_map_calls[0] += 1
		raise Exception("fail")
	caught = False
	try:
		root.map_descendant_bottom_up(fail_map)
	except Exception:
		caught = True
	assert(caught)
	assert(fail_map_calls[0] > 0)

	count[0] = 0
	def inc_count_visit(node):
		count[0] += 1
		return True
	root.foreach_descendant_top_down(inc_count_visit)
	assert(count[0] == 8)

	# Returning False prevents descending below the current node, so only
	# the root itself is visited.
	count[0] = 0
	def inc_count_stop(node):
		count[0] += 1
		return False
	root.foreach_descendant_top_down(inc_count_stop)
	assert(count[0] == 1)

	def is_not_domain(node):
		return not isinstance(node, isl.schedule_node_domain)
	assert(root.child(0).every_descendant(is_not_domain))
	assert(not root.every_descendant(is_not_domain))

	fail_calls = [0]
	def fail(node):
		fail_calls[0] += 1
		raise Exception("fail")
	caught = False
	try:
		root.every_descendant(fail)
	except Exception:
		caught = True
	assert(caught)
	assert(fail_calls[0] > 0)

	domain = root.domain()
	filters = [isl.union_set("{}")]
	def collect_filters(node):
		if isinstance(node, isl.schedule_node_filter):
			filters[0] = filters[0].union(node.filter())
		return True
	root.every_descendant(collect_filters)
	assert(domain.is_equal(filters[0]))
388
# Check that band members can be marked for unrolling.
# "schedule" is the schedule built by construct_schedule_tree, covering two
# statements with 10 and 20 instances.  With every band member marked for
# unrolling, the AST generator must invoke the at-domain callback 30 times.
#
def test_ast_build_unroll(schedule):
	def mark_unroll(node):
		if type(node) != isl.schedule_node_band:
			return node
		return node.member_set_ast_loop_unroll(0)
	unrolled = schedule.root().map_descendant_bottom_up(mark_unroll).schedule()

	calls = [0]
	def count_call(node, build):
		calls[0] += 1
		return node

	builder = isl.ast_build().set_at_each_domain(count_call)
	builder.node_from(unrolled)
	assert calls[0] == 30
413
# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
# - generate an AST from the schedule tree
# - test at_each_domain: setting the callback yields a new build and leaves
#   the original build without the callback
# - test that an exception raised from the callback is propagated and that
#   the build can still be used afterwards
# - test unrolling
#
def test_ast_build():
	schedule = construct_schedule_tree()

	count_ast = [0]
	def inc_count_ast(node, build):
		count_ast[0] += 1
		return node

	build = isl.ast_build()
	build_copy = build.set_at_each_domain(inc_count_ast)
	# The original build must not have picked up the callback.
	build.node_from(schedule)
	assert(count_ast[0] == 0)
	count_ast[0] = 0
	build_copy.node_from(schedule)
	assert(count_ast[0] == 2)
	build = build_copy
	count_ast[0] = 0
	build.node_from(schedule)
	assert(count_ast[0] == 2)

	# fail_inc_count_ast reads do_fail each time it is called, so clearing
	# do_fail further down turns it into a plain counting callback.
	do_fail = True
	count_ast_fail = [0]
	def fail_inc_count_ast(node, build):
		count_ast_fail[0] += 1
		if do_fail:
			raise Exception("fail")
		return node
	build = isl.ast_build()
	build = build.set_at_each_domain(fail_inc_count_ast)
	caught = False
	try:
		build.node_from(schedule)
	except Exception:
		caught = True
	assert(caught)
	assert(count_ast_fail[0] > 0)
	# Replacing the callback on a derived build must not affect "build".
	build_copy = build.set_at_each_domain(inc_count_ast)
	count_ast[0] = 0
	build_copy.node_from(schedule)
	assert(count_ast[0] == 2)
	count_ast_fail[0] = 0
	do_fail = False
	build.node_from(schedule)
	assert(count_ast_fail[0] == 2)

	test_ast_build_unroll(schedule)
468
# Check that AST expression generation from the affine expression "n + 1"
# yields an addition operation with two arguments.
#
def test_ast_build_expr():
	pa = isl.pw_aff("[n] -> { [n + 1] }")
	expr = isl.ast_build.from_context(pa.domain()).expr_from(pa)

	assert type(expr) == isl.ast_expr_op_add
	assert expr.n_arg() == 2
478
# Run the tests of the isl Python interface, in order, covering:
#  - object construction
#  - the different parameter types
#  - the different return types
#  - isl.id.user
#  - foreach functions
#  - the foreach SCC function
#  - every functions
#  - spaces
#  - schedule trees
#  - AST generation
#  - AST expression generation
#
for _test in (
	test_constructors,
	test_parameters,
	test_return,
	test_user,
	test_foreach,
	test_foreach_scc,
	test_every,
	test_space,
	test_schedule_tree,
	test_ast_build,
	test_ast_build_expr,
):
	_test()
505