# * Copyright 2016, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

import re
import graph_refine.syntax as syntax
import graph_refine.problem as problem
import graph_refine.stack_logic as stack_logic
from graph_refine.syntax import true_term, false_term, mk_not
from graph_refine.check import *
import graph_refine.search as search

import elf_parser
import graph_refine.target_objects as target_objects

from imm_utils import *
from elf_file import *
from addr_utils import *
from call_graph_utils import gFuncsCalled
from dot_utils import toDot,toGraph
from addr_utils import gToPAddrP,callNodes


def loadCounts(dir_name):
    """
    Load the loop-count table stored in <dir_name>/loop_counts.py.

    The file is executed as Python source and must define a dict named
    'loops_by_fs'; that dict is returned unchanged.
    """
    #loop_counts.py must define a dict called loops_by_fs
    context = {}
    #NOTE: execfile is a Python 2 builtin and runs the file as code, so
    #dir_name must only ever point at trusted data.
    execfile('%s/loop_counts.py' % dir_name, context)

    #presumably f -> loop head -> bound, like immFunc.loops_by_fs
    assert 'loops_by_fs' in context
    return context['loops_by_fs']


class immFunc (Borg):
    """
    Binary-level ("imm") control flow graph of one ELF function together
    with every function it transitively calls.

    Naming used throughout:
      p_n   node name inside a graph_refine problem (an int, or one of the
            special names such as 'Ret' / 'Err')
      f     name of the function whose problem the node lives in
      p_nf  the pair (p_n, f)
      g_n   the address recorded in the problem's node tag for p_n
            (see phyAddr)

    NOTE(review): Borg comes from a star import; presumably it is the usual
    shared-state base class, which would explain the argument-less
    constructor path -- confirm in imm_utils.
    """

    def __init__(self,elf_fun=None,load_counts=False):
        """
        elf_fun:     object providing .name, .addr and .g_f for the root
                     function; when omitted nothing is initialised.
        load_counts: when True, read loop counts from loop_counts.py in
                     elfFile().dir_name instead of generating dummies.
        """
        Borg.__init__(self)
        if not elf_fun:
            return
        self.elf_fun = elf_fun
        self.name = elf_fun.name
        self.addr = elf_fun.addr
        self.g_f = elf_fun.g_f
        self.asm_fs = elfFile().asm_fs
        #g_n -> immNode
        self.imm_nodes = {}
        self.bbs = {}
        self.loaded_loop_counts = False
        self.parse_only = False
        self.loop_bounds = {}
        # dict of f-> loop_heads -> (bound, description)
        self.loops_by_fs = {}
        #f -> p_n
        self.p_entries = {}
        if load_counts:
            self.loaded_loops_by_fs = loadCounts(elfFile().dir_name)
            self.loaded_loop_counts = True

    def process(self):
        """
        Build the binary graph and the loop tables. Does nothing when
        self.bbs is already populated.

        Afterwards self.bin_loops_by_fs is either the table loaded from
        file, or self.loops_by_fs with every loop head translated through
        phyAddrP.
        """
        if self.bbs != {}:
            return
        self.makeBinGraph()
        self.loopheads = {}
        self.findLoopheads()
        lbfs = self.loops_by_fs
        if self.loaded_loop_counts:
            self.bin_loops_by_fs = self.loaded_loops_by_fs
            print('loaded loop counts from file')
            return

        #build bin_loops_by_fs from loops_by_fs
        self.bin_loops_by_fs = {}
        blbf = self.bin_loops_by_fs
        for f in lbfs:
            blbf[f] = {}
            p = self.f_problems[f]
            loops = lbfs[f]
            for p_head in loops:
                g_head = phyAddrP(p_head, p)
                #two heads of one function must not share an address.
                #(this used to test against blbf itself, whose keys are
                #function names, so it could never fire)
                assert g_head not in blbf[f]
                blbf[f][g_head] = loops[p_head]

    def isBBHead(self,p_nf):
        """Return True iff p_nf is a real node whose address is a key of self.bbs."""
        if not self.isRealNode(p_nf):
            return False
        g_n = self.phyAddr(p_nf)
        if not isinstance(g_n, int):
            return False
        return g_n in self.bbs

    #bin addr to bb addr
    def bbAddr(self,addr):
        """
        Return the key of the self.bbs entry that contains addr.
        Prints addr and fails an assertion when no entry contains it.
        """
        bbs = self.bbs
        for head in bbs:
            if addr in bbs[head]:
                return head
        print('addr: %x' % addr)
        assert False, 'BB not found !!'

    def toPhyAddrs(self, p_nis):
        """Map phyAddr over a sequence of p_nf pairs, returning a list."""
        return [self.phyAddr(x) for x in p_nis]

    #find all possible entries of the loop for Chronos
    def findLoopEntries(self, loop, f):
        """
        Return the p_n of every node of `loop` (an iterable of p_n in
        function f) that can act as a loop entry.

        A loop entry must be
        1. a basic block head and
        2. have >=1 edges_to entry that lies outside the loop
        Deadend nodes are never reported.
        """
        real_ns = [x for x in list(loop) if self.isRealNode((x, f))]
        #set: only used for membership tests below
        loop_phys = set(self.toPhyAddrs([(x, f) for x in real_ns]))
        entries = []
        for x in real_ns:
            if (x, f) in self.pf_deadends:
                ##gotta be halt / branch to halt
                continue
            node = self.imm_nodes[self.phyAddr((x, f))]
            has_ext_edge = any(y not in loop_phys for y in node.edges_to)
            if has_ext_edge and self.isBBHead((x, f)):
                entries.append(x)
        return entries

    def findLoopheads(self):
        """
        Run loop analysis on every problem and fill in
          self.loops_by_fs    f -> loop head p_n -> dummy (bound, desc, id)
          self.imm_loopheads  address of a loop entry -> loop head p_nf
        Loop heads tagged as belonging to a deadend function are skipped.
        Asserts that at least one loop head was found.
        """
        self.imm_loopheads = {}
        loopheads = []
        loops_by_fs = self.loops_by_fs
        #iterate over a snapshot of the problems
        for f, p in list(self.f_problems.items()):
            p.compute_preds()
            p.do_loop_analysis()
            heads = p.loop_heads()
            if heads:
                loops_by_fs[f] = {}
            for x in heads:
                fun, _ = self.pNToFunGN((x, f))
                #dodge halt
                if fun in elfFile().deadend_funcs:
                    continue
                loopheads.append((x, f))
                #the 0 worker_id will get ignored by genLoopHeads.
                #FIXME: do this properly..
                loops_by_fs[f][x] = (2**30,'dummy',0)
        assert loopheads
        for p_nf in loopheads:
            p_n, f = p_nf
            p = self.f_problems[f]
            loop_body = p.loop_data[p_n][1]
            #map from potential heads -> head, hack around chronos 'feature'
            for q in self.findLoopEntries(loop_body, f):
                g_q = self.phyAddr((q, f))
                #the table is keyed by address, so the duplicate check must
                #use the address too (it used to test the raw p_n)
                assert g_q not in self.imm_loopheads, 'one addr cannot have >1 loopcounts !'
                self.imm_loopheads[g_q] = p_nf

        return

    def firstRealNodes(self,p_nf,visited = None,may_multi=False,may_call=False,skip_ret=False):
        """
        Locate the first real node from, and including, p_nf,
        or branch targets if it hits a branch before that.
        Returns a list of p_nf

        visited:   p_nf already expanded on this path; a revisit yields [].
        may_multi: permit following a node with several continuations
                   (asserted otherwise).
        may_call:  forwarded to pNodeConts, permitting a walk into a callee.
        skip_ret:  return [] instead of failing when 'Ret' is reached while
                   the f of the starting p_nf is not the root function.
        """
        p_n,f = p_nf
        next_p_nf = p_nf
        if visited is None:
            #print 'fRN on p_n %d, fun: %s' % (p_n,f)
            visited = []

        if p_nf in visited:
            return []
        visited.append(p_nf)

        assert self.pf_deadends is not None
        while True:
            if self.isRealNode(next_p_nf):
                return [next_p_nf]
            next_p_n, next_f, next_p = self.unpackPNF(next_p_nf)
            if next_p_n == 'Ret':
                #note: compares the f we started from, not next_f
                if f == self.name:
                    return [('Ret',f)]
                if skip_ret:
                    return []
                assert False,'firstRealNodes reached Ret when skip_ret is False'
            p_node, edges = self.pNodeConts(next_p_nf, may_call=may_call)
            if edges == []:
                return []
            if len(edges) == 1:
                next_p_nf = edges[0]
                continue

            #a branch: collect the first real nodes along every edge, each
            #with its own copy of the visited list
            assert may_multi
            ret = []
            for p_e in edges:
                ret.extend(self.firstRealNodes(p_e, visited = list(visited),
                    may_multi=may_multi, may_call=may_call, skip_ret=skip_ret))
            return ret

    #function p_n belongs to, g_n
    def pNToFunGN(self,p_nf):
        """Return (function name, g_n) as stored in the node tag of p_nf."""
        p_n,f,p = self.unpackPNF(p_nf)
        _, detail = p.node_tags[p_n]
        f_name, g_n = detail
        return f_name,g_n

    #given p_n is an imm call, return is_tailcall
    def isCallTailCall(self,p_nf):
        """Return elf_parser.isDirectBranch for the address of p_nf."""
        g_n = self.phyAddr(p_nf)
        return elf_parser.isDirectBranch(g_n)

    def isStraightToRetToRoot(self,p_nf):
        """
        Return True iff p_nf is the root function's 'Ret', or reaches it
        through a chain of non-real nodes that each have exactly one
        continuation.
        """
        p_n,f,p = self.unpackPNF(p_nf)
        if p_n == 'Ret':
            return f == self.name
        if self.isRealNode(p_nf):
            return False
        if self.phyAddr(p_nf) == 'RetToCaller':
            return False
        if isinstance(p_n, int):
            _,pf_conts = self.pNodeConts(p_nf)
            if len(pf_conts) == 1:
                return self.isStraightToRetToRoot((pf_conts[0][0],f))
        return False

    #whether the corresponding imm has a return edge
    def isImmRootReturn(self,p_nf):
        """
        Return True iff p_nf is in the root function and one of its
        continuations leads straight to the root 'Ret'.
        """
        p_n,f = p_nf
        if f != self.name:
            return False
        _, pf_conts = self.pNodeConts(p_nf)
        for pf_cont in pf_conts:
            if self.isStraightToRetToRoot(pf_cont):
                return True
        return False

    #whether p_n leads straightly to RetToCaller
    def isStraightToRetToCaller(self,p_nf):
        """
        Return True iff p_nf is a 'Ret' of a non-root function, is tagged
        'RetToCaller', or reaches either through a chain of non-real nodes
        that each have exactly one continuation.
        """
        p_n,f = p_nf
        if p_n == 'Ret':
            return f != self.name
        if self.isRealNode(p_nf):
            return False
        if self.phyAddr(p_nf) == "RetToCaller":
            return True
        if isinstance(p_n, int):
            _,pf_conts = self.pNodeConts(p_nf)
            if len(pf_conts) == 1:
                return self.isStraightToRetToCaller((pf_conts[0][0],f))
        return False

    #All return except the root one
    def isImmRetToCaller(self,p_nf):
        """
        Return the continuation of p_nf that leads straight to a
        return-to-caller, or False when there is none or p_nf is a call
        node. Asserts that at most one continuation qualifies.
        """
        p_n,f,p = self.unpackPNF(p_nf)
        if isCall(p.nodes[p_n]):
            return False
        p_node,pf_conts = self.pNodeConts(p_nf)

        rtc_conts = []
        for pf_cont in pf_conts:
            cont_n,cont_f = pf_cont
            if isCall(self.f_problems[cont_f].nodes[cont_n]):
                continue
            if self.isStraightToRetToCaller(pf_cont):
                rtc_conts.append(pf_cont)
        assert len(rtc_conts) <= 1
        if rtc_conts:
            return rtc_conts[0]
        return False

    def funName(self,p_nf):
        """Return the .fname of node p_nf with every '.' replaced by '_'."""
        p_n,f = p_nf
        fname = self.f_problems[f].nodes[p_n].fname
        return fname.replace('.', '_')

    def makeProblem(self,f):
        """Create a problem with asm function f as its 'ASM' entry, analyse and return it."""
        p = problem.Problem(None, 'Functions (%s)' % f)
        p.add_entry_function(self.asm_fs[f], 'ASM')
        p.do_analysis()
        return p

    def isSpecInsFunc(self,f):
        """
        Returns whether f is the name of a special function
        used to model special instruction
        """
        return f.startswith ("instruction'")

    def makeBinGraph(self):
        """
        Prepare problems for all functions transitively called by self,
        and turn this into a binary CFG

        Fills self.f_problems, self.p_entries and self.imm_nodes, then sets
        self.imm_entry and self.bbs.
        """
        self.f_problems = {}
        if self.name not in elfFile().tcg:
            #diagnostic only, the lookup below will still raise KeyError
            print(elfFile().tcg.keys())
        tc_fs = list(elfFile().tcg[self.name])
        for f in tc_fs + [self.name]:
            assert '.' not in f
            if self.isSpecInsFunc(f):
                continue
            p = problem.Problem(None, 'Functions (%s)' % f)
            p.add_entry_function(self.asm_fs[f], 'ASM')
            self.f_problems[f] = p
            #get its entry
            assert len(p.entries) == 1
            self.p_entries[f] = p.entries[0][0]

        print('all problems generated')
        self.findAllDeadends()
        print("all deadends found")
        #now generate the bin graph

        for f,p in self.f_problems.items():
            for p_n in p.nodes:
                if not isinstance(p_n, int):
                    continue
                p_nf = (p_n,f)
                if p_nf in self.pf_deadends:
                    continue
                if self.isRealNode(p_nf):
                    self.addImmNode(p_nf)

        root_entry = (self.p_entries[self.name], self.name)
        self.imm_entry = self.phyAddr(self.firstRealNodes(root_entry)[0])
        self.bbs = findBBs(self.imm_entry,self)

    def findAllDeadends(self):
        """
        Fill self.deadend_g_ns with the address of every deadend function
        and self.pf_deadends with every p_nf for which isDeadend holds.
        """
        self.pf_deadends = []
        pf_deadends = self.pf_deadends
        self.deadend_g_ns = set()
        #Halt is a deadend function, and should never be called, it's equivalent to Err for our purpose
        for dead_f in elfFile().deadend_funcs:
            print('dead_f %s' % dead_f)
            deadend_f_g_n = elfFile().funcs[dead_f].addr
            self.deadend_g_ns.add(deadend_f_g_n)
            print('deadend_f_g_n 0x%x' % deadend_f_g_n)

        for f,p in self.f_problems.items():
            for p_n in p.nodes:
                if self.isDeadend((p_n,f)):
                    pf_deadends.append((p_n,f))

    def isDeadend(self,p_nf,visited=None):
        '''
        Determine if p_nf (p_n, function) is a deadend node

        A node is a deadend when it is 'Err', has a deadend address, calls
        a deadend function, or when every continuation is itself a
        deadend. A node met again on the current path counts as dead.
        visited holds the p_nf on the current path.
        '''
        if p_nf in self.pf_deadends:
            return True
        p_n, f, p = self.unpackPNF(p_nf)
        if visited is None:
            visited = []
        if p_n == 'Err':
            return True
        if p_n == 'Ret':
            return False
        if p_nf in visited:
            return True
        if isCall(p.nodes[p_n]):
            #walk into the callee problem
            f = self.funName(p_nf)
            #FIXME: dodge dummy functions
            if 'instruction' in f:
                return False
            if f in elfFile().deadend_funcs:
                return True
            p_callee = self.f_problems[f]
            assert len(p_callee.entries) == 1
            p_callee_n = p_callee.entries[0][0]
            return self.isDeadend((p_callee_n,f),visited=visited + [p_nf])

        if isinstance(p_n, int) and self.phyAddr(p_nf) == 'RetToCaller':
            return False
        g_n = self.phyAddr(p_nf)

        if g_n in self.deadend_g_ns:
            return True

        #note: pNodeConts ensures we stay in the same problem
        node,fconts = self.pNodeConts(p_nf)
        for p_c in [x[0] for x in fconts]:
            assert p_c != p_n
            if not self.isDeadend((p_c,f), visited = visited + [p_nf]):
                return False

        #all ends are dead, thus deadend
        return True

    def unpackPNF(self,p_nf):
        """Expand p_nf = (p_n, f) into (p_n, f, problem of f)."""
        p_n,f = p_nf
        p = self.f_problems[f]
        return (p_n,f,p)

    def phyAddr (self,p_nf) :
        """
        Return the address stored in the node tag of p_nf.

        Non-int node names (e.g. 'Ret') are returned unchanged and a
        'LoopReturn' tag yields 'LoopReturn'. Any other tag detail must
        unpack as (function name, address); otherwise the tag is printed
        and an assertion fails.
        """
        p_n, f, p = self.unpackPNF(p_nf)
        if not isinstance(p_n,int):
            return p_n
        _,x = p.node_tags[p_n]
        if x == 'LoopReturn':
            return 'LoopReturn'
        try:
            f_name,g_addr = x
        except (TypeError, ValueError):
            #tag detail is not a 2-element sequence
            print(f)
            print('tags: %s'% str(p.node_tags[p_n]))
            assert False
        return g_addr

    #must not reach Ret
    def pNodeConts(self, p_nf, no_deadends=False, may_call = False):
        """
        Return (problem node of p_nf, list of continuation p_nf).

        For a call node (only permitted with may_call) the single
        continuation is the entry of the called function, in the callee's
        problem. Otherwise the continuations are the node's get_conts()
        without 'Err', all in the same function f.

        no_deadends: additionally drop continuations in self.pf_deadends.
        """
        p_n,f,p = self.unpackPNF(p_nf)
        p_node = p.nodes[p_n]
        if isCall(p_node):
            assert may_call
            fun_called = self.funName(p_nf)
            entry = self.p_entries[fun_called]
            return p_node, [(entry,fun_called)]
        assert p_n != 'Ret'
        p_conts = [x for x in p_node.get_conts() if x != 'Err']
        if no_deadends:
            #this filter used to reference the undefined names p_i and
            #pi_deadends and raised NameError whenever it was requested
            p_conts = [x for x in p_conts if (x, f) not in self.pf_deadends]
        pf_conts = [(x, f) for x in p_conts]
        return p_node,pf_conts

    def isRealNode(self,p_nf):
        """
        Return True iff the address of p_nf is an int multiple of 4.

        'Ret', 'RetToCaller' and LoopReturn nodes are never real. Any other
        non-int address fails an assertion.
        NOTE(review): multiples of 4 are presumably instruction addresses,
        other ints intermediate nodes -- confirm against the decompiler.
        """
        p_n,f = p_nf
        if p_n == 'Ret':
            return False
        g_n = self.phyAddr(p_nf)
        if g_n == 'RetToCaller':
            return False
        if self.isLoopReturn(p_nf):
            return False
        if not isinstance(g_n, int):
            print('g_n %s' % str(g_n))
            assert False, 'g_n expected of typ int'
        return g_n % 4 == 0

    def isLoopReturn(self,p_nf):
        """Return True iff the node tag detail of p_nf is 'LoopReturn'."""
        p_n,f = p_nf
        p = self.f_problems[f]
        tag = p.node_tags[p_n]
        return tag[1] == 'LoopReturn'

    def addImmNode(self,p_nf):
        """
        Create (or reuse) the immNode for the address of p_nf and attach
        its outgoing edges: call/return edges when a continuation is a call
        node, a return edge when it returns, and one plain edge per first
        real node reachable from each continuation.

        A call to a special-instruction function only gets an edge to the
        call's return node; nothing else is added in that case.
        """
        imm_nodes = self.imm_nodes
        g_n = self.phyAddr(p_nf)
        p_node,pf_conts = self.pNodeConts(p_nf)
        p_conts = [x[0] for x in pf_conts]
        p_n,f,p = self.unpackPNF(p_nf)
        if g_n in imm_nodes:
            #we have been here before
            node = imm_nodes[g_n]
        else:
            node = immNode(g_n,rawVals(g_n))
            imm_nodes[g_n] = node

        #edge targets that are still added, but with emit=False
        dont_emit = []
        p_imm_return_to_caller_edge = self.isImmRetToCaller(p_nf)
        call_pn = self.getCallTarg(p_nf)
        if call_pn is not None:
            fun_called = self.funName((call_pn, f))
            if self.isSpecInsFunc(fun_called):
                #Hack: go straight to the return node, do nothing else
                next_addrs = p.nodes[call_pn].get_conts()
                assert len(next_addrs) == 1
                next_addr = next_addrs[0]
                assert next_addr not in ['Ret','Err']
                phy_next_addr = self.phyAddr((next_addr,f))
                node.addEdge(immEdge(phy_next_addr, emit = True))
                return
            imm_call = self.parseImmCall(p_nf)
            assert not p_imm_return_to_caller_edge
            g_call_targ,g_ret_addr,is_tail_call = imm_call
            dont_emit.append(g_call_targ)
            node.addCallRetEdges(g_call_targ, g_ret_addr,is_tail_call)

        elif p_imm_return_to_caller_edge or self.isImmRootReturn(p_nf):
            node.addRetEdge()

        #add edges to the imm node, ignore Err and halt
        for p_targ in p_conts:
            #non-int targets such as 'Ret' are skipped here
            if not isinstance(p_targ, int):
                continue
            if (p_targ, f) in self.pf_deadends:
                continue
            edges = self.firstRealNodes((p_targ,f),may_multi=True,may_call=True,skip_ret=True)
            for p_e in edges:
                #dodge halt
                if p_e in self.pf_deadends:
                    continue
                g_e = self.phyAddr(p_e)
                assert g_e is not None
                if g_e == 'Ret':
                    continue
                node.addEdge(immEdge(g_e,emit = g_e not in dont_emit))

    def retPF(self,call_p_nf):
        """Return the p_nf of the sole continuation of call node call_p_nf."""
        p_n,f,p = self.unpackPNF(call_p_nf)
        conts = p.nodes[p_n].get_conts()
        assert len(conts) == 1
        return (conts[0], f)

    def getCallTarg(self, p_nf):
        """
        Return the p_n of the call node directly following p_nf, or None
        when no continuation is a call. Asserts there is at most one.
        """
        p_n,f,p = self.unpackPNF(p_nf)
        _, pf_conts = self.pNodeConts(p_nf)
        #is Imm call iff there is a successor of kind Call in the g graph
        p_n_cs = [p_n_c for (p_n_c, _) in pf_conts
            if isinstance(p_n_c, int)
            and not self.isLoopReturn((p_n_c, f))
            and isCall(self.gNode((p_n_c, f)))]
        if not p_n_cs:
            return None
        assert len(p_n_cs) == 1
        #return the p_n of the call node
        return p_n_cs[0]

    def parseImmCall(self,p_nf):
        """
        Returns (entry point to the called function, return addr, is_tailcall)

        The return addr is None for a tail call and an int address
        otherwise (asserted). p_nf must have a call continuation.
        """
        call_pn = self.getCallTarg(p_nf)
        assert call_pn is not None

        p_n,f,p = self.unpackPNF(p_nf)
        suc = self.firstRealNodes((call_pn, f), may_multi=False, may_call=True)
        pf_call_targ = suc[0]
        g_call_targ = self.phyAddr(pf_call_targ)
        #locate the call return address
        is_tailcall = self.isCallTailCall(p_nf)
        if is_tailcall:
            phy_ret_addr = None
        else:
            phy_ret_addr = self.phyAddr(self.retPF((call_pn,f)))

        assert isinstance(phy_ret_addr, int) or is_tailcall, "g_call_targ %s phy_ret_addr %s" % (g_call_targ, phy_ret_addr)
        return (g_call_targ, phy_ret_addr,is_tailcall)

    def gNode(self,p_nf):
        """Return the asm graph node that the tag of p_nf refers to."""
        p_n,f,p = self.unpackPNF(p_nf)
        g_f, g_n = p.node_tags[p_n][1][0], p.node_tags[p_n][1][1]
        return self.asm_fs[g_f].nodes[g_n]