# Copyright 2016-2017 Tobias Grosser
#
# Use of this software is governed by the MIT license
#
# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich

import sys
import isl

# Test that isl objects can be constructed.
#
# This tests:
#  - construction from a string
#  - construction from an integer
#  - static constructor without a parameter
#  - conversion construction
#  - construction of empty union set
#
# Constructing from integers and strings overlaps with the parameter type
# tests; the point here is to exercise overload resolution between the
# several constructors a single class offers.
#
def test_constructors():
    # Three different routes to the value zero.
    from_string = isl.val("0")
    assert from_string.is_zero()

    from_int = isl.val(0)
    assert from_int.is_zero()

    from_static = isl.val.zero()
    assert from_static.is_zero()

    # A basic_set passed to the set constructor yields an equal set.
    single_point = isl.basic_set("{ [1] }")
    expected_set = isl.set("{ [1] }")
    converted = isl.set(single_point)
    assert converted.is_equal(expected_set)

    # Union with the empty union set leaves the original unchanged.
    some_points = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert some_points.is_equal(some_points.union(nothing))

# Test integer function parameters for a particular integer value.
#
# The isl.val built from the Python integer must compare equal to the one
# parsed from the integer's decimal string.
#
def test_int(i):
    assert isl.val(i).eq(isl.val(str(i)))

# Test integer function parameters.
#
# Verify that extreme values and zero work.
#
def test_parameters_int():
    for value in (sys.maxsize, -sys.maxsize - 1, 0):
        test_int(value)

# Test isl object parameters.
#
# Verify that isl objects can be passed as lvalue and rvalue parameters.
# Also verify that isl object parameters are automatically type converted if
# there is an inheritance relation. Finally, test function calls without
# any additional parameters, apart from the isl object on which
# the method is called.
#
def test_parameters_obj():
    a = isl.set("{ [0] }")
    b = isl.set("{ [1] }")
    c = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # Intermediate result stored in a named variable (lvalue argument).
    tmp = a.union(b)
    res_lvalue_param = tmp.union(c)
    assert(res_lvalue_param.is_equal(expected))

    # Intermediate result passed on directly (rvalue argument).
    res_rvalue_param = a.union(b).union(c)
    assert(res_rvalue_param.is_equal(expected))

    # A basic_set is accepted where a set parameter is expected.
    a2 = isl.basic_set("{ [0] }")
    assert(a.is_equal(a2))

    # Method call whose only argument is the object itself.
    two = isl.val(2)
    half = isl.val("1/2")
    res_only_this_param = two.inv()
    assert(res_only_this_param.eq(half))

# Test different kinds of parameters to be passed to functions.
#
# This includes integer and isl object parameters.
#
def test_parameters():
    test_parameters_int()
    test_parameters_obj()

# Test that isl objects are returned correctly.
#
# This only tests that after combining two objects, the result is successfully
# returned.
#
def test_return_obj():
    one = isl.val("1")
    two = isl.val("2")
    three = isl.val("3")

    res = one.add(two)

    assert(res.eq(three))

# Test that integer values are returned correctly.
#
# sgn() must be positive, negative and zero for 1, -1 and 0 respectively.
#
def test_return_int():
    one = isl.val("1")
    neg_one = isl.val("-1")
    zero = isl.val("0")

    assert(one.sgn() > 0)
    assert(neg_one.sgn() < 0)
    assert(zero.sgn() == 0)

# Test that isl_bool values are returned correctly.
#
# In particular, check the conversion to bool in case of true and false.
#
def test_return_bool():
    empty = isl.set("{ : false }")
    univ = isl.set("{ : }")

    b_true = empty.is_empty()
    b_false = univ.is_empty()

    assert(b_true)
    assert(not b_false)

# Test that strings are returned correctly.
# Do so by calling the overloaded isl.ast_build.expr_from methods.
#
def test_return_string():
    ctx_set = isl.set("[n] -> { : }")
    ast_builder = isl.ast_build.from_context(ctx_set)
    n_expr = isl.pw_aff("[n] -> { [n] }")
    n_non_negative = isl.set("[n] -> { : n >= 0 }")

    # Overload taking a piecewise affine expression.
    generated = ast_builder.expr_from(n_expr)
    assert "n" == generated.to_C_str()

    # Overload taking a set (turned into a condition).
    generated = ast_builder.expr_from(n_non_negative)
    assert "n >= 0" == generated.to_C_str()

# Test that return values are handled correctly.
#
# Runs the object, integer, boolean and string return-value tests in turn.
#
def test_return():
    for check in (test_return_obj, test_return_int, test_return_bool,
                  test_return_string):
        check()

# A class that is used to test isl.id.user.
#
# Instances carry a single attribute, value, set to 42.
#
class S:
    def __init__(self):
        self.value = 42

# Test isl.id.user.
#
# An identifier created with a user object must hand back that same
# object; an identifier created without one must report None.
#
def test_user():
    int_id = isl.id("test", 5)
    plain_id = isl.id("test2")
    obj_id = isl.id("S", S())

    assert int_id.user() == 5, f"unexpected user object {int_id.user()}"
    assert plain_id.user() is None, f"unexpected user object {plain_id.user()}"

    attached = obj_id.user()
    assert isinstance(attached, S), f"unexpected user object {attached}"
    assert attached.value == 42, f"unexpected user object {attached}"

# Test that foreach functions are modeled correctly.
#
# Verify that closures are correctly called as callback of a 'foreach'
# function and that variables captured by the closure work correctly. Also
# check that the foreach function handles exceptions thrown from
# the closure and that it propagates the exception.
#
def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # The closure captures "pieces" from the enclosing scope.
    pieces = []
    def add(bs):
        pieces.append(bs)
    s.foreach_basic_set(add)

    assert(len(pieces) == 3)
    for piece in pieces:
        assert(piece.is_subset(s))
    # All collected basic sets must be pairwise different.
    for i in range(len(pieces)):
        for j in range(i + 1, len(pieces)):
            assert(not pieces[i].is_equal(pieces[j]))

    def fail(bs):
        raise Exception("fail")

    # Check that it is the callback's own exception that comes back out.
    # A bare "except:" would also accept unrelated errors, such as an
    # AttributeError from a misspelled method name.
    try:
        s.foreach_basic_set(fail)
    except Exception as exc:
        assert str(exc) == "fail", f"unexpected exception {exc!r}"
    else:
        raise AssertionError("exception raised by callback was not propagated")

# Test the functionality of "foreach_scc" functions.
#
# In particular, test it on a list of elements that can be completely sorted
# but where two of the elements ("a" and "b") are incomparable.
#
def test_foreach_scc():
    ids = isl.id_list(3)
    # Single-element list so that the nested callback can rebind the result.
    ordered = [isl.id_list(3)]
    data = {
        'a' : isl.map("{ [0] -> [1] }"),
        'b' : isl.map("{ [1] -> [0] }"),
        'c' : isl.map("{ [i = 0:1] -> [i] }"),
    }
    for name in data:
        ids = ids.add(name)
    identity = data['a'].space().domain().identity_multi_pw_aff_on_domain()
    def follows(a, b):
        composed = data[b.name()].apply_domain(data[a.name()])
        return not composed.lex_ge_at(identity).is_empty()

    def add_single(scc):
        # Every strongly connected component is expected to be a singleton.
        assert(scc.size() == 1)
        ordered[0] = ordered[0].concat(scc)

    ids.foreach_scc(follows, add_single)
    assert(ordered[0].size() == 3)
    assert(ordered[0].at(0).name() == "b")
    assert(ordered[0].at(1).name() == "c")
    assert(ordered[0].at(2).name() == "a")

# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
# test that exceptions are properly propagated.
#
def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()
    assert(not us.every_set(is_empty))

    def is_non_empty(s):
        return not s.is_empty()
    assert(us.every_set(is_non_empty))

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))
    assert(not us.every_set(in_A))

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))
    assert(not us.every_set(not_in_A))

    def fail(s):
        raise Exception("fail")

    # The callback's exception must propagate out of every_set. The previous
    # version called the misspelled "ever_set"; the resulting AttributeError
    # was swallowed by a bare "except:", so propagation was never exercised.
    try:
        us.every_set(fail)
    except Exception as exc:
        assert str(exc) == "fail", f"unexpected exception {exc!r}"
    else:
        raise AssertionError("exception raised by callback was not propagated")

# Check basic construction of spaces.
#
def test_space():
    unit = isl.space.unit()
    set_space = unit.add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)
    assert(universe_set.is_equal(isl.set("{ A[*,*,*] }")))
    assert(universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }")))

# Construct a simple schedule tree with an outer sequence node and
# a single-dimensional band node in each branch, with one of them
# marked coincident.
#
def construct_schedule_tree():
    A = isl.union_set("{ A[i] : 0 <= i < 10 }")
    B = isl.union_set("{ B[i] : 0 <= i < 20 }")

    node = isl.schedule_node.from_domain(A.union(B))
    node = node.child(0)

    filters = isl.union_set_list(A).add(B)
    node = node.insert_sequence(filters)

    # First branch: band over A, member 0 marked coincident.
    f_A = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    node = node.child(0)
    node = node.child(0)
    node = node.insert_partial_schedule(f_A)
    node = node.member_set_coincident(0, True)
    node = node.ancestor(2)

    # Second branch: band over B.
    f_B = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    node = node.child(1)
    node = node.child(0)
    node = node.insert_partial_schedule(f_B)
    node = node.ancestor(2)

    return node.schedule()

# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
# - check that the root node is a domain node
# - test map_descendant_bottom_up
# - test foreach_descendant_top_down
# - test every_descendant
#
def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    # Exact type check: the binding must return the specific node subclass.
    assert(type(root) is isl.schedule_node_domain)

    # construct_schedule_tree builds domain, sequence, and for each of the
    # two branches a filter, a band and a leaf: 8 nodes in total.
    count = [0]
    def count_and_keep(node):
        count[0] += 1
        return node
    root = root.map_descendant_bottom_up(count_and_keep)
    assert(count[0] == 8)

    def fail_map(node):
        raise Exception("fail")
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception as exc:
        assert str(exc) == "fail", f"unexpected exception {exc!r}"
    else:
        raise AssertionError("exception raised by callback was not propagated")

    # Returning True continues into the children, so every node is visited.
    count = [0]
    def count_and_descend(node):
        count[0] += 1
        return True
    root.foreach_descendant_top_down(count_and_descend)
    assert(count[0] == 8)

    # Returning False stops the descent, so only the root is visited.
    count = [0]
    def count_and_stop(node):
        count[0] += 1
        return False
    root.foreach_descendant_top_down(count_and_stop)
    assert(count[0] == 1)

    def is_not_domain(node):
        return type(node) is not isl.schedule_node_domain
    assert(root.child(0).every_descendant(is_not_domain))
    assert(not root.every_descendant(is_not_domain))

    def fail(node):
        raise Exception("fail")
    try:
        root.every_descendant(fail)
    except Exception as exc:
        assert str(exc) == "fail", f"unexpected exception {exc!r}"
    else:
        raise AssertionError("exception raised by callback was not propagated")

    # The union of all filters must cover exactly the schedule domain.
    domain = root.domain()
    filters = [isl.union_set("{}")]
    def collect_filters(node):
        if type(node) is isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True
    root.every_descendant(collect_filters)
    assert(domain.is_equal(filters[0]))

# Test marking band members for unrolling.
# "schedule" is the schedule created by construct_schedule_tree.
# It schedules two statements, with 10 and 20 instances, respectively.
# Unrolling all band members therefore results in 30 at-domain calls
# by the AST generator.
#
def test_ast_build_unroll(schedule):
    root = schedule.root()
    def mark_unroll(node):
        if type(node) is isl.schedule_node_band:
            node = node.member_set_ast_loop_unroll(0)
        return node
    root = root.map_descendant_bottom_up(mark_unroll)
    schedule = root.schedule()

    count_ast = [0]
    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(inc_count_ast)
    # Only the callback count matters; the generated AST itself is unused.
    build.node_from(schedule)
    assert(count_ast[0] == 30)

# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
# - generate an AST from the schedule tree
# - test at_each_domain
# - test unrolling
#
def test_ast_build():
    schedule = construct_schedule_tree()

    count_ast = [0]
    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # set_at_each_domain returns a new build; the original one must
    # still generate without invoking the callback.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert(count_ast[0] == 0)
    # One at-domain call per statement (A and B) without unrolling.
    build_copy.node_from(schedule)
    assert(count_ast[0] == 2)
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert(count_ast[0] == 2)

    # "do_fail" is read through the closure, so rebinding it below
    # changes the callback's behavior.
    do_fail = True
    count_ast_fail = [0]
    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node
    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    try:
        build.node_from(schedule)
    except Exception as exc:
        assert str(exc) == "fail", f"unexpected exception {exc!r}"
    else:
        raise AssertionError("exception raised by callback was not propagated")
    assert(count_ast_fail[0] > 0)
    # A build derived from the one whose callback raised must work normally.
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert(count_ast[0] == 2)
    # The original build must be reusable once its callback stops raising.
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert(count_ast_fail[0] == 2)

    test_ast_build_unroll(schedule)

# Test basic AST expression generation from an affine expression.
#
def test_ast_build_expr():
    pa = isl.pw_aff("[n] -> { [n + 1] }")
    build = isl.ast_build.from_context(pa.domain())

    op = build.expr_from(pa)
    assert(type(op) is isl.ast_expr_op_add)
    assert(op.n_arg() == 2)

# Test the isl Python interface
#
# This includes:
#  - Object construction
#  - Different parameter types
#  - Different return types
#  - isl.id.user
#  - Foreach functions
#  - Foreach SCC function
#  - Every functions
#  - Spaces
#  - Schedule trees
#  - AST generation
#  - AST expression generation
#
test_constructors()
test_parameters()
test_return()
test_user()
test_foreach()
test_foreach_scc()
test_every()
test_space()
test_schedule_tree()
test_ast_build()
test_ast_build_expr()