#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

from syntax import (Expr, mk_var, Node, true_term, false_term,
        fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs)
import syntax

from target_objects import functions, pairings, trace, printout
import sys
import logic
from logic import azip

class Abort(Exception):
    """Raised (after an 'Aborting ...' message is printed) when a
    problem cannot be built or analysed."""
    pass

# Holds the most recently constructed Problem (set by Problem.__init__).
last_problem = [None]

class Problem:
    """A graph of nodes built from one or more functions, together with
    the bookkeeping used by the loop and dependency analyses.

    Nodes live in self.nodes keyed by address; every node is tagged in
    self.node_tags with a (tag, detail) pair recording its origin, and
    self.node_tag_revs maps each (tag, detail) back to the list of
    addresses carrying it."""

    def __init__ (self, pairing, name = None):
        if name is None:
            name = pairing.name
        self.name = 'Problem (%s)' % name
        self.pairing = pairing

        self.nodes = {}
        self.vs = {}
        self.next_node_name = 1
        self.preds = {}
        self.loop_data = {}
        self.node_tags = {}
        self.node_tag_revs = {}
        self.inline_scripts = {}
        self.entries = []
        self.outputs = {}
        self.tarjan_order = []
        self.loop_var_analysis_cache = {}

        self.known_eqs = {}
        self.cached_analysis = {}
        self.hook_tag_hints = {}

        last_problem[0] = self

    def fail_msg (self):
        return 'FAILED %s (size %05d)' % (self.name, len(self.nodes))

    def alloc_node (self, tag, detail, loop_id = None, hint = None):
        """Allocate a fresh integer node address tagged (tag, detail).

        If loop_id is given the new node is recorded as a member ('Mem')
        of the loop headed at loop_id. 'hint' is accepted but unused."""
        name = self.next_node_name
        self.next_node_name = name + 1

        self.node_tags[name] = (tag, detail)
        self.node_tag_revs.setdefault ((tag, detail), []).append (name)

        if loop_id is not None:
            self.loop_data[name] = ('Mem', loop_id)

        return name

    def fresh_var (self, name, typ):
        name = fresh_name (name, self.vs, typ)
        return mk_var (name, typ)

    def clone_function (self, fun, tag):
        """Make this problem a copy of fun's reachable graph, keeping
        fun's own node addresses, with a single entry tagged 'tag'.

        Note the node objects are shared with fun, not copied."""
        self.nodes = {}
        self.vs = syntax.get_vars (fun)
        for n in fun.reachable_nodes ():
            self.nodes[n] = fun.nodes[n]
            detail = (fun.name, n)
            self.node_tags[n] = (tag, detail)
            self.node_tag_revs.setdefault ((tag, detail), []).append (n)
        self.outputs[tag] = fun.outputs
        self.entries = [(fun.entry, tag, fun.name, fun.inputs)]
        self.next_node_name = max (list (self.nodes) + [2]) + 1
        self.inline_scripts[tag] = []

    def add_function (self, fun, tag, node_renames, loop_id = None):
        """Copy fun's reachable nodes into this problem under freshly
        allocated addresses and fresh variable names.

        node_renames is updated in place with the address mapping; its
        'Ret' and 'Err' entries default to themselves, and a caller may
        preset them to redirect fun's exits. Returns (new_node_renames,
        vs): the address mapping for fun's own nodes and the variable
        renaming.

        Raises Abort if fun has no entry or if check_no_symbols finds
        symbols in its nodes."""
        if not fun.entry:
            printout ('Aborting %s: underspecified %s' % (
                self.name, fun.name))
            raise Abort ()
        node_renames.setdefault ('Ret', 'Ret')
        node_renames.setdefault ('Err', 'Err')
        new_node_renames = {}
        fun_vs = syntax.get_vars (fun)
        vs = dict ([(v, fresh_name (v, self.vs, fun_vs[v]))
            for v in fun_vs])
        ns = fun.reachable_nodes ()
        check_no_symbols ([fun.nodes[n] for n in ns], name = self.name)
        for n in ns:
            assert n not in node_renames
            node_renames[n] = self.alloc_node (tag, (fun.name, n),
                loop_id = loop_id, hint = n)
            new_node_renames[n] = node_renames[n]
        for n in ns:
            self.nodes[node_renames[n]] = syntax.copy_rename (
                fun.nodes[n], (vs, node_renames))

        return (new_node_renames, vs)

    def add_entry_function (self, fun, tag):
        """Add fun as a new entry point tagged 'tag'.

        Returns (args, rets, entry): the renamed inputs and outputs as
        (name, type) pairs, and the address of the entry node."""
        (ns, vs) = self.add_function (fun, tag, {})

        entry = ns[fun.entry]
        args = [(vs[v], typ) for (v, typ) in fun.inputs]
        rets = [(vs[v], typ) for (v, typ) in fun.outputs]
        self.entries.append ((entry, tag, fun.name, args))
        self.outputs[tag] = rets

        self.inline_scripts[tag] = []

        return (args, rets, entry)

    def get_entry_details (self, tag):
        """Return (entry node, function name, args) for the unique
        entry tagged 'tag' (fails unless exactly one matches)."""
        [(e, t, fname, args)] = [(e, t, fname, args)
            for (e, t, fname, args) in self.entries if t == tag]
        return (e, fname, args)

    def get_entry (self, tag):
        (e, fname, args) = self.get_entry_details (tag)
        return e

    def tags (self):
        return list (self.outputs)

    def entry_exit_renames (self, tags = None):
        """computes the rename set of a function's formal parameters
        to the actual input/output variable names at the various entry
        and exit points"""
        def mk (xs, ys):
            return dict ([(x[0], y[0]) for (x, y) in azip (xs, ys)])
        renames = {}
        if tags is None:
            tags = self.tags ()
        for tag in tags:
            (_, fname, args) = self.get_entry_details (tag)
            fun = functions[fname]
            out = self.outputs[tag]
            renames[tag + '_IN'] = mk (fun.inputs, args)
            renames[tag + '_OUT'] = mk (fun.outputs, out)
        return renames

    def redirect_conts (self, reds):
        """Rewrite, in place, every continuation found in 'reds' to its
        mapped target."""
        for node in self.nodes.values ():
            if node.kind == 'Cond':
                node.left = reds.get (node.left, node.left)
                node.right = reds.get (node.right, node.right)
            else:
                node.cont = reds.get (node.cont, node.cont)

    def do_analysis (self):
        """Drop cached analyses and recompute predecessors and loops."""
        self.cached_analysis.clear ()
        self.compute_preds ()
        self.do_loop_analysis ()

    def mk_node_graph (self, node_subset = None):
        """Adjacency dict of the nodes in node_subset (default: all),
        keeping only continuations that are also in the subset."""
        if node_subset is None:
            node_subset = self.nodes
        return dict ([(n, [c for c in self.nodes[n].get_conts ()
                if c in node_subset])
            for n in node_subset])

    def do_loop_analysis (self):
        """Find loops with logic.tarjan and fill self.loop_data and
        self.tarjan_order, adding a single LoopReturn node per loop
        where force_single_loop_return asks for one."""
        entries = [e for (e, tag, nm, args) in self.entries]
        self.loop_data = {}

        graph = self.mk_node_graph ()
        comps = logic.tarjan (graph, entries)
        self.tarjan_order = []

        for (head, tail) in comps:
            self.tarjan_order.append (head)
            self.tarjan_order.extend (tail)
            # a singleton component is only a loop if it has a self-arc
            if not tail and head not in graph[head]:
                continue
            trace ('Loop (%d, %s)' % (head, tail))

            loop_set = set (tail)
            loop_set.add (head)

            r = self.force_single_loop_return (head, loop_set)
            if r is not None:
                tail.append (r)
                loop_set.add (r)
                self.tarjan_order.append (r)
                self.compute_preds ()

            self.loop_data[head] = ('Head', loop_set)
            for t in tail:
                self.loop_data[t] = ('Mem', head)

        # put this in first-to-last order.
        self.tarjan_order.reverse ()

    def check_no_inner_loops (self):
        for loop in self.loop_heads ():
            check_no_inner_loop (self, loop)

    def force_single_loop_return (self, head, loop_set):
        """Ensure the only in-loop predecessor of 'head' is a no-op node.

        If head already has exactly one in-loop predecessor, other than
        head itself, and it is a no-op, returns None. Otherwise allocates
        an empty 'LoopReturn' Basic node jumping to head, redirects every
        in-loop predecessor of head to it, and returns its address."""
        rets = [n for n in self.preds[head] if n in loop_set]
        if (len (rets) == 1 and rets[0] != head and
                self.nodes[rets[0]].is_noop ()):
            return None
        r = self.alloc_node (self.node_tags[head][0],
            'LoopReturn', loop_id = head)
        self.nodes[r] = Node ('Basic', head, [])
        for r2 in rets:
            self.nodes[r2] = syntax.copy_rename (self.nodes[r2],
                ({}, {head: r}))
        return r

    def splittable_points (self, n):
        """splittable points are points which when removed, the loop
        'splits' and ceases to be a loop.

        equivalently, the set of splittable points is the intersection
        of all sub-loops of the loop."""
        head = self.loop_id (n)
        assert head is not None
        k = ('Splittables', head)
        if k in self.cached_analysis:
            return self.cached_analysis[k]

        # check if the head point is a split (the inner loop
        # check does exactly that)
        if has_inner_loop (self, head):
            head = logic.get_one_loop_splittable (self,
                self.loop_body (head))
            if head is None:
                return set ()

        splits = self.get_loop_splittables (head)
        self.cached_analysis[k] = splits
        return splits

    def get_loop_splittables (self, head):
        """Compute the splittable points of the loop containing 'head',
        assuming head itself is one.

        Walks one arc from head back to head, then clears the
        splittable flag of every arc node that some in-loop path skips
        over."""
        loop_set = self.loop_body (head)
        splittable = dict ([(n, False) for n in loop_set])
        arc = [head]
        n = head
        while True:
            ns = [n2 for n2 in self.nodes[n].get_conts ()
                if n2 in loop_set]
            ns2 = [x for x in ns if x == head or x not in arc]
            n = ns2[0]
            arc.append (n)
            splittable[n] = True
            if n == head:
                break
        last_descs = {}
        for i in range (len (arc)):
            last_descs[arc[i]] = i
        def last_desc (n):
            # latest arc index reachable from n within the loop
            if n in last_descs:
                return last_descs[n]
            n2s = [n2 for n2 in self.nodes[n].get_conts ()
                if n2 in loop_set]
            last_descs[n] = None
            for n2 in n2s:
                x = last_desc (n2)
                if last_descs[n] is None or x >= last_descs[n]:
                    last_descs[n] = x
            return last_descs[n]
        for i in range (len (arc)):
            max_arc = max ([last_desc (n)
                for n in self.nodes[arc[i]].get_conts ()
                if n in loop_set])
            for j in range (i + 1, max_arc):
                splittable[arc[j]] = False
        return set ([n for n in splittable if splittable[n]])

    def loop_heads (self):
        return [n for n in self.loop_data
            if self.loop_data[n][0] == 'Head']

    def loop_id (self, n):
        """Return the head of the loop containing n, or None if n is
        not in a loop."""
        if n not in self.loop_data:
            return None
        elif self.loop_data[n][0] == 'Head':
            return n
        else:
            assert self.loop_data[n][0] == 'Mem'
            return self.loop_data[n][1]

    def loop_body (self, n):
        head = self.loop_id (n)
        return self.loop_data[head][1]

    def compute_preds (self):
        self.preds = logic.compute_preds (self.nodes)

    def var_dep_outputs (self, n):
        return self.outputs[self.node_tags[n][0]]

    def compute_var_dependencies (self):
        """Return (cached) a dict mapping every node to a dict whose
        keys are the variables logic.compute_var_deps reports for it."""
        if 'var_dependencies' in self.cached_analysis:
            return self.cached_analysis['var_dependencies']
        var_deps = logic.compute_var_deps (self.nodes,
            self.var_dep_outputs, self.preds)
        var_deps2 = dict ([(n, dict ([(v, None)
                for v in var_deps.get (n, [])]))
            for n in self.nodes])
        self.cached_analysis['var_dependencies'] = var_deps2
        return var_deps2

    def get_loop_var_analysis (self, var_deps, n):
        """Return logic.compute_loop_var_analysis for splittable point
        n, memoised in self.loop_var_analysis_cache.

        A cached result is reused only if the loop body's nodes, their
        predecessors and their variable dependencies all compare equal
        to when it was computed; the 10 most recent results are kept
        per (n, loop body) key."""
        head = self.loop_id (n)
        assert head, n
        assert n in self.splittable_points (n)
        loop_sort = tuple (sorted (self.loop_body (head)))
        # snapshot of the loop: note the predecessors recorded are
        # those of each body node n2 (previously only n's were used).
        node_data = [(self.nodes[n2], sorted (self.preds[n2]),
                sorted (var_deps[n2].keys ()))
            for n2 in loop_sort]
        k = (n, loop_sort)
        data = (node_data, n)
        for (data2, va) in self.loop_var_analysis_cache.get (k, []):
            if data2 == data:
                return va
        va = logic.compute_loop_var_analysis (self, var_deps, n)
        group = self.loop_var_analysis_cache.setdefault (k, [])
        group.append ((data, va))
        del group[:-10]
        return va

    def save_graph (self, fname):
        cols = mk_graph_cols (self.node_tags)
        save_graph (self.nodes, fname, cols = cols,
            node_tags = self.node_tags)

    def save_graph_summ (self, fname):
        """Save a DOT graph with chains of trivial nodes (single
        predecessor, Basic, or Cond whose right branch is 'Err')
        collapsed into the node they lead to."""
        node_ids = {}
        def is_triv (n):
            if n not in self.nodes:
                return False
            if len (self.preds[n]) != 1:
                return False
            node = self.nodes[n]
            if node.kind == 'Basic':
                return (True, node.cont)
            elif node.kind == 'Cond' and node.right == 'Err':
                return (True, node.left)
            else:
                return False
        for n in self.nodes:
            if n in node_ids:
                continue
            ns = []
            while is_triv (n):
                ns.append (n)
                n = is_triv (n)[1]
            for n2 in ns:
                node_ids[n2] = n
        nodes = {}
        for n in self.nodes:
            if is_triv (n):
                continue
            nodes[n] = syntax.copy_rename (self.nodes[n],
                ({}, node_ids))
        cols = mk_graph_cols (self.node_tags)
        save_graph (nodes, fname, cols = cols,
            node_tags = self.node_tags)

    def serialise (self):
        """Return the problem as a list of text lines, starting with
        'Problem' and ending with 'EndProblem' (see deserialise)."""
        ss = ['Problem']
        for (n, tag, fname, inputs) in self.entries:
            xs = ['Entry', '%d' % n, tag, fname,
                '%d' % len (inputs)]
            for (nm, typ) in inputs:
                xs.append (nm)
                typ.serialise (xs)
            xs.append ('%d' % len (self.outputs[tag]))
            for (nm, typ) in self.outputs[tag]:
                xs.append (nm)
                typ.serialise (xs)
            ss.append (' '.join (xs))
        for n in self.nodes:
            xs = ['%d' % n]
            self.nodes[n].serialise (xs)
            ss.append (' '.join (xs))
        ss.append ('EndProblem')
        return ss

    def save_serialise (self, fname):
        ss = self.serialise ()
        with open (fname, 'w') as f:
            for s in ss:
                f.write (s + '\n')

    def pad_merge_points (self):
        """Insert an empty 'MergePadding' Basic node on every arc into
        a node with several predecessors, unless that predecessor is
        already a Basic node with no updates."""
        self.compute_preds ()

        arcs = [(pred, n) for n in self.preds
            if len (self.preds[n]) > 1
            if n in self.nodes
            for pred in self.preds[n]
            if (self.nodes[pred].kind != 'Basic'
                or self.nodes[pred].upds != [])]

        for (pred, n) in arcs:
            (tag, _) = self.node_tags[pred]
            name = self.alloc_node (tag, 'MergePadding')
            self.nodes[name] = Node ('Basic', n, [])
            self.nodes[pred] = syntax.copy_rename (self.nodes[pred],
                ({}, {n: name}))

    def function_call_addrs (self):
        return [(n, self.nodes[n].fname)
            for n in self.nodes if self.nodes[n].kind == 'Call']

    def function_calls (self):
        return set ([fn for (n, fn) in self.function_call_addrs ()])

    def get_extensions (self):
        if 'extensions' in self.cached_analysis:
            return self.cached_analysis['extensions']
        extensions = set ()
        for node in self.nodes.values ():
            extensions.update (syntax.get_extensions (node))
        self.cached_analysis['extensions'] = extensions
        return extensions

    def replay_inline_script (self, tag, script):
        """Re-apply inlining steps recorded by inline_at_point as
        (detail, idx, fname) triples, then re-run the analysis if any
        step was replayed."""
        for (detail, idx, fname) in script:
            n = self.node_tag_revs[(tag, detail)][idx]
            assert self.nodes[n].kind == 'Call', self.nodes[n]
            assert self.nodes[n].fname == fname, self.nodes[n]
            inline_at_point (self, n, do_analysis = False)
        if script:
            self.do_analysis ()

    def is_reachable_from (self, source, target):
        '''discover if graph addr "target" is reachable
            from starting node "source"'''
        k = ('is_reachable_from', source)
        if k in self.cached_analysis:
            reachable = self.cached_analysis[k]
            if target in reachable:
                return reachable[target]

        reachable = {}
        visit = [source]
        while visit:
            n = visit.pop ()
            if n not in self.nodes:
                continue
            for n2 in self.nodes[n].get_conts ():
                if n2 not in reachable:
                    reachable[n2] = True
                    visit.append (n2)
        for n in list (self.nodes) + ['Ret', 'Err']:
            if n not in reachable:
                reachable[n] = False
        self.cached_analysis[k] = reachable
        return reachable[target]

    def is_reachable_without (self, cutpoint, target):
        '''discover if graph addr "target" is reachable
            without visiting node "cutpoint"
            (an oddity: cutpoint itself is considered reachable)'''
        k = ('is_reachable_without', cutpoint)
        if k in self.cached_analysis:
            reachable = self.cached_analysis[k]
            if target in reachable:
                return reachable[target]

        reachable = dict ([(self.get_entry (t), True)
            for t in self.tags ()])
        for n in self.tarjan_order + ['Ret', 'Err']:
            if n in reachable:
                continue
            reachable[n] = bool ([pred for pred in self.preds[n]
                if pred != cutpoint
                if reachable.get (pred) == True])
        self.cached_analysis[k] = reachable
        return reachable[target]

def deserialise (name, lines):
    """Rebuild a Problem from the lines produced by Problem.serialise.

    The pairing cannot be recovered, so it is set to None."""
    assert lines[0] == 'Problem', lines[0]
    assert lines[-1] == 'EndProblem', lines[-1]
    i = 1
    # not easy to reconstruct pairing
    p = Problem (pairing = None, name = name)
    while lines[i].startswith ('Entry'):
        bits = lines[i].split ()
        en = int (bits[1])
        tag = bits[2]
        fname = bits[3]
        (n, inputs) = syntax.parse_list (syntax.parse_lval, bits, 4)
        (n, outputs) = syntax.parse_list (syntax.parse_lval, bits, n)
        assert n == len (bits), (n, bits)
        p.entries.append ((en, tag, fname, inputs))
        p.outputs[tag] = outputs
        i += 1
    for i in range (i, len (lines) - 1):
        bits = lines[i].split ()
        n = int (bits[0])
        node = syntax.parse_node (bits, 1)
        p.nodes[n] = node
    return p

# trivia

def check_no_symbols (nodes, name = None):
    """Print an abort message and raise Abort if
    pseudo_compile.nodes_symbols reports any symbols in 'nodes'.

    name, if given, identifies the problem in the message. (Previously
    this read 'self.name', which raised NameError at module level.)"""
    import pseudo_compile
    symbs = pseudo_compile.nodes_symbols (nodes)
    if not symbs:
        return
    if name is None:
        name = 'problem'
    printout ('Aborting %s: undefined symbols %s' % (name, symbs))
    raise Abort ()

# printing of problem graphs

def sanitise_str (s):
    return s.replace ('"', '_').replace ("'", "_").replace (' ', '')

def graph_name (nodes, node_tags, n, prev=None):
    """Return a DOT-friendly label for node n, built from its address
    and tag and prefixed according to its kind."""
    if type (n) == str:
        # '%s' rather than '%d': identical for int 'prev', and does not
        # raise when 'prev' is left as None.
        return 't_%s_%s' % (n, prev)
    if n not in nodes:
        return 'unknown_%d' % n
    if n not in node_tags:
        ident = '%d' % n
    else:
        (tag, details) = node_tags[n]
        if len (details) > 1 and logic.is_int (details[1]):
            ident = '%d_%s_%s_0x%x' % (n, tag,
                details[0], details[1])
        elif type (details) != str:
            details = '_'.join (map (str, details))
            ident = '%d_%s_%s' % (n, tag, details)
        else:
            ident = '%d_%s_%s' % (n, tag, details)
    ident = sanitise_str (ident)
    node = nodes[n]
    if node.kind == 'Call':
        return 'fcall_%s' % ident
    if node.kind == 'Cond':
        return ident
    if node.kind == 'Basic':
        return 'ass_%s' % ident
    assert not 'node kind understood'

def graph_node_tooltip (nodes, n):
    if n == 'Err':
        return 'Error point'
    if n == 'Ret':
        return 'Return point'
    node = nodes[n]
    if node.kind == 'Call':
        return "%s: call to '%s'" % (n, sanitise_str (node.fname))
    if node.kind == 'Cond':
        return '%s: conditional node' % n
    if node.kind == 'Basic':
        var_names = [sanitise_str (x[0][0]) for x in node.upds]
        return '%s: assignment to [%s]' % (n, ', '.join (var_names))
    assert not 'node kind understood'

def graph_edges (nodes, n):
    """Return the outgoing edges of n as (target, label) pairs:
    'N' for a no-op, 'T'/'F' for a Cond, 'C' otherwise."""
    node = nodes[n]
    if node.is_noop ():
        return [(node.get_conts () [0], 'N')]
    elif node.kind == 'Cond':
        return [(node.left, 'T'), (node.right, 'F')]
    else:
        return [(node.cont, 'C')]

def get_graph_font (n, col):
    font = 'fontname = "Arial", fontsize = 20, penwidth=3'
    if col:
        font = font + ', color=%s, fontcolor=%s' % (col, col)
    return font

def get_graph_loops (nodes):
    """Return the set of (n, n2) arcs that lie within a single
    strongly connected component (i.e. inside a loop)."""
    graph = dict ([(n, [c for c in nodes[n].get_conts ()
            if type (c) != str]) for n in nodes])
    graph['ENTRY'] = list (nodes)
    comps = logic.tarjan (graph, ['ENTRY'])
    comp_ids = {}
    for (head, tail) in comps:
        comp_ids[head] = head
        for n in tail:
            comp_ids[n] = head
    loops = set ([(n, n2) for n in graph for n2 in graph[n]
        if comp_ids[n] == comp_ids[n2]])
    return loops

def make_graph (nodes, cols, node_tags = None, entries = None):
    """Return the lines of a DOT digraph for 'nodes'.

    cols maps node addresses to colours; entries is a list of
    (node, tag, inputs) triples drawn as entry arrows. Arcs inside
    loops are drawn thicker."""
    if node_tags is None:
        node_tags = {}
    if entries is None:
        entries = []
    graph = []
    graph.append ('digraph foo {')

    loops = get_graph_loops (nodes)

    for n in nodes:
        n_nm = graph_name (nodes, node_tags, n)
        f = get_graph_font (n, cols.get (n))
        graph.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n,
            f, n_nm, graph_node_tooltip (nodes, n)))
        for (c, l) in graph_edges (nodes, n):
            if c in ['Ret', 'Err']:
                # give each exit its own DOT node per predecessor
                c_nm = '%s_%s' % (c, n)
                if c == 'Ret':
                    f2 = f + ', shape=doubleoctagon'
                else:
                    f2 = f + ', shape=Mdiamond'
                graph.append ('%s [label="%s", %s];'
                    % (c_nm, c, f2))
            else:
                c_nm = c
            ft = f
            if (n, c) in loops:
                ft = f + ', penwidth=6'
            graph.append ('%s -> %s [label=%s, %s];' % (
                n, c_nm, l, ft))

    for (i, (n, tag, inps)) in enumerate (entries):
        f = get_graph_font (n, cols.get (n))
        nm1 = tag + ' ENTRY_POINT'
        nm2 = 'entry_point_%d' % i
        graph.extend (['%s -> %s [%s];' % (nm2, n, f),
            '%s [label = "%s", shape=none, %s];' % (nm2, nm1, f)])

    graph.append ('}')
    return graph

def print_graph (nodes, cols = None, entries = None):
    """Print the DOT graph for 'nodes' to stdout."""
    if cols is None:
        cols = {}
    # entries must go by keyword: make_graph's third positional
    # parameter is node_tags.
    for line in make_graph (nodes, cols, entries = entries):
        print (line)

def save_graph (nodes, fname, cols = None, entries = None,
        node_tags = None):
    """Write the DOT graph for 'nodes' to the file fname."""
    if cols is None:
        cols = {}
    with open (fname, 'w') as f:
        for line in make_graph (nodes, cols = cols,
                node_tags = node_tags, entries = entries):
            f.write (line + '\n')

def mk_graph_cols (node_tags):
    """Colour nodes by tag: 'C', 'ASM_adj' and 'ASM' get fixed colours;
    other tags are left uncoloured."""
    known_cols = {'C': "forestgreen", 'ASM_adj': "darkblue",
        'ASM': "darkorange"}
    cols = {}
    for n in node_tags:
        if node_tags[n][0] in known_cols:
            cols[n] = known_cols[node_tags[n][0]]
    return cols

def make_graph_with_eqs (p, invis = False):
    """Return DOT lines for p's graph plus a blue back-arrow for each
    entry in p.known_eqs (except 'Hyps'); invis hides those arrows."""
    if invis:
        invis_s = ', style=invis'
    else:
        invis_s = ''
    cols = mk_graph_cols (p.node_tags)
    graph = make_graph (p.nodes, cols = cols)
    # drop the closing '}' so the extra edges land inside the digraph
    graph.pop ()
    for k in p.known_eqs:
        if k == 'Hyps':
            continue
        (n_vc_x, tag_x) = k
        nm1 = graph_name (p.nodes, p.node_tags, n_vc_x[0])
        for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]:
            nm2 = graph_name (p.nodes, p.node_tags, n_vc_y[0])
            graph.extend ([('%s -> %s [ dir = back, color = blue, '
                'penwidth = 3, weight = 0 %s ]')
                    % (nm2, nm1, invis_s)])
    graph.append ('}')
    return graph

def save_graph_with_eqs (p, fname = 'diagram.dot', invis = False):
    graph = make_graph_with_eqs (p, invis = invis)
    with open (fname, 'w') as f:
        for s in graph:
            f.write (s + '\n')

def get_problem_vars (p):
    """Return a dict of p's variables: the entry inputs and outputs,
    plus whatever syntax.get_node_vars adds for each node (presumably
    name -> type; verify against syntax)."""
    # set ().union copes with a problem that has no entries/outputs,
    # where set.union () with no arguments would raise TypeError.
    inout = set ().union (* ([set (xs) for xs in p.outputs.values ()]
        + [set (args) for (_, _, _, args) in p.entries]))

    vs = dict (inout)
    for node in p.nodes.values ():
        syntax.get_node_vars (node, vs)
    return vs

def is_trivial_fun (fun):
    """True if fun makes no calls, assigns only variables or numerals,
    and branches only on variables or literal true/false."""
    for node in fun.nodes.values ():
        if node.is_noop ():
            continue
        if node.kind == 'Call':
            return False
        elif node.kind == 'Basic':
            for (lv, v) in node.upds:
                if v.kind not in ['Var', 'Num']:
                    return False
        elif node.kind == 'Cond':
            if node.cond.kind != 'Var' and node.cond not in [
                    true_term, false_term]:
                return False
    return True

last_alt_nodes = [0]

def avail_val (vs, typ):
    """Return a variable from vs of type typ if there is one, else
    logic.default_val (typ)."""
    for (nm, typ2) in vs:
        if typ2 == typ:
            return mk_var (nm, typ2)
    return logic.default_val (typ)

def inline_at_point (p, n, do_analysis = True):
    """Inline the function called at node n of problem p.

    The call node becomes an assignment of the arguments to the callee's
    renamed inputs, and a new 'RetToCaller' node assigns the callee's
    outputs to the call's return variables before continuing. The step
    is recorded in p.inline_scripts so replay_inline_script can repeat
    it. Does nothing if n is not a Call. Returns the addresses of the
    inlined nodes (None if nothing was inlined)."""
    node = p.nodes[n]
    if node.kind != 'Call':
        return

    f_nm = node.fname
    fun = functions[f_nm]
    (tag, detail) = p.node_tags[n]
    idx = p.node_tag_revs[(tag, detail)].index (n)
    p.inline_scripts[tag].append ((detail, idx, f_nm))

    trace ('Inlining %s into %s' % (f_nm, p.name))
    if n in p.loop_data:
        trace (' inlining into loop %d!' % p.loop_id (n))

    ex = p.alloc_node (tag, (f_nm, 'RetToCaller'))

    (ns, vs) = p.add_function (fun, tag, {'Ret': ex})
    en = ns[fun.entry]

    inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs]
    p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args))

    out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs]
    p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs))

    p.cached_analysis.clear ()

    if do_analysis:
        p.do_analysis ()

    trace ('Problem size now %d' % len(p.nodes))
    # flush the trace output (previously sys.stdin was flushed)
    sys.stdout.flush ()

    return list (ns.values ())

def loop_body_inner_loops (p, head, loop_body):
    """Return the non-trivial strongly connected components (as
    (head, tail) pairs) of the loop body with 'head' removed."""
    loop_set_all = set (loop_body)
    loop_set = loop_set_all - set ([head])
    graph = dict ([(n, [c for c in p.nodes[n].get_conts ()
            if c in loop_set])
        for n in loop_set_all])

    comps = logic.tarjan (graph, [head])
    assert sum ([1 + len (t) for (_, t) in comps]) == len (loop_set_all)
    return [comp for comp in comps if comp[1]]

def loop_inner_loops (p, head):
    k = ('inner_loop_set', head)
    if k in p.cached_analysis:
        return p.cached_analysis[k]
    res = loop_body_inner_loops (p, head, p.loop_body (head))
    p.cached_analysis[k] = res
    return res

def loop_heads_including_inner (p):
    """Return p's loop heads plus the heads of all nested inner loops."""
    heads = p.loop_heads ()
    check = [(head, p.loop_body (head)) for head in heads]
    while check:
        (head, body) = check.pop ()
        comps = loop_body_inner_loops (p, head, body)
        heads.extend ([h for (h, _) in comps])
        check.extend ([(h, [h] + list (tail))
            for (h, tail) in comps])
    return heads

def check_no_inner_loop (p, head):
    """Raise Abort if the loop at 'head' contains an inner loop."""
    subs = loop_inner_loops (p, head)
    if subs:
        printout ('Aborting %s, complex loop' % p.name)
        trace (' sub-loops %s of loop at %s' % (subs, head))
        for (h, _) in subs:
            trace (' head %d tagged %s' % (h, p.node_tags[h]))
        raise Abort ()

def has_inner_loop (p, head):
    return bool (loop_inner_loops (p, head))

def fun_has_inner_loop (f):
    p = f.as_problem (Problem)
    p.do_analysis ()
    return bool ([head for head in p.loop_heads ()
        if has_inner_loop (p, head)])

def loop_var_analysis (p, head, tail):
    """Return the variables that go round the loop: those read in the
    loop before being written on that path, intersected with those
    written somewhere before control returns to head."""
    # getting the set of variables that go round the loop
    nodes = set (tail)
    nodes.add (head)
    used_vs = set ([])
    created_vs_at = {}
    visit = []

    def process_node (n, created):
        if p.nodes[n].is_noop ():
            lvals = set ([])
        else:
            vs = syntax.get_node_rvals (p.nodes[n])
            for rv in vs.items ():
                if rv not in created:
                    used_vs.add (rv)
            lvals = set (p.nodes[n].get_lvals ())

        created = set.union (created, lvals)
        created_vs_at[n] = created

        visit.extend (p.nodes[n].get_conts ())

    process_node (head, set ([]))

    while visit:
        n = visit.pop ()
        if (n not in nodes) or (n in created_vs_at):
            continue
        # wait until every predecessor has been processed
        if not all ([pr in created_vs_at for pr in p.preds[n]]):
            continue

        pre_created = [created_vs_at[pr] for pr in p.preds[n]]
        process_node (n, set.union (* pre_created))

    final_pre_created = [created_vs_at[pr] for pr in p.preds[head]
        if pr in nodes]
    created = set.union (* final_pre_created)

    loop_vs = set.intersection (created, used_vs)
    trace ('Loop vars at head: %s' % loop_vs)

    return loop_vs