1# * Copyright 2016, NICTA
2# *
3# * This software may be distributed and modified according to the terms
4# of
5# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
6# * See "LICENSE_BSD2.txt" for details.
7# *
8# * @TAG(NICTA_BSD)
9
10import re
11import graph_refine.syntax as syntax
12import graph_refine.problem as problem
13import graph_refine.stack_logic as stack_logic
14from graph_refine.syntax import true_term, false_term, mk_not
15from graph_refine.check import *
16import graph_refine.search as search
17
18import elf_parser
19import graph_refine.target_objects as target_objects
20
21from imm_utils import *
22from elf_file import *
23from addr_utils import *
24from call_graph_utils import gFuncsCalled
25from dot_utils import toDot,toGraph
26from addr_utils import gToPAddrP,callNodes
27
def loadCounts(dir_name):
    """
    Load manually supplied loop bounds from '<dir_name>/loop_counts.py'.

    The file is executed as Python source in a fresh namespace and must
    define a name called ``loops_by_fs``; that object is returned as-is.
    Its shape is not validated here. Presumably it is
    f -> loop-head address -> (bound, description, worker_id), matching the
    table immFunc.process derives itself -- TODO confirm.

    NOTE: the file is executed, not parsed, so only load trusted files.

    Raises AssertionError if the file does not define loops_by_fs.
    """
    path = '%s/loop_counts.py' % dir_name
    context = {}
    # portable replacement for Python 2's execfile(); closes the file promptly
    with open(path) as f:
        source = f.read()
    exec(compile(source, path, 'exec'), context)

    assert 'loops_by_fs' in context, '%s does not define loops_by_fs' % path
    return context['loops_by_fs']
37
38class immFunc (Borg):
39    def __init__(self,elf_fun=None,load_counts=False):
40        Borg.__init__(self)
41        if not elf_fun:
42            return
43        self.elf_fun = elf_fun
44        self.name = elf_fun.name
45        self.addr = elf_fun.addr
46        self.g_f = elf_fun.g_f
47        self.asm_fs = elfFile().asm_fs
48        self.imm_nodes = {}
49        self.bbs = {}
50        self.loaded_loop_counts = False
51        self.parse_only = False
52        self.loop_bounds = {}
53        # dict of f-> loop_heads -> (bound, description)
54        self.loops_by_fs = {}
55        #f -> p_n
56        self.p_entries = {}
57        if load_counts:
58           self.loaded_loops_by_fs = loadCounts(elfFile().dir_name)
59           self.loaded_loop_counts = True
60
61    def process(self):
62        if self.bbs != {}:
63            return
64        self.makeBinGraph()
65        self.loopheads = {}
66        self.findLoopheads()
67        lbfs = self.loops_by_fs
68        if self.loaded_loop_counts:
69            self.bin_loops_by_fs = self.loaded_loops_by_fs
70            print 'loaded loop counts from file'
71        else:
72            #build bin_loops_by_fs from loops_by_fs
73            self.bin_loops_by_fs = {}
74            blbf = self.bin_loops_by_fs
75            for f in lbfs:
76                blbf[f] = {}
77                p = self.f_problems[f]
78                pA = lambda x: phyAddrP(x,p)
79                loops = lbfs[f]
80                for p_head in loops:
81                    assert pA(p_head) not in blbf
82                    blbf[f][pA(p_head)] = loops[p_head]
83
84    def isBBHead(self,p_nf):
85        if not self.isRealNode(p_nf):
86          return False
87        g_n = self.phyAddr(p_nf)
88        if not type(g_n) == int:
89          return False
90        return g_n in self.bbs
91
92    #bin addr to bb addr
93    def bbAddr(self,addr):
94        bbs = self.bbs
95        for x in bbs:
96          if addr in bbs[x]:
97            return x
98        print 'addr: %x' % addr
99        assert False, 'BB not found !!'
100
101    def toPhyAddrs(self, p_nis):
102      return [self.phyAddr(x) for x in p_nis]
103
104    #find all possible entries of the loop for Chronos
105    def findLoopEntries(self, loop, f):
106        p = self.f_problems[f]
107        head = None
108        lp = [x for x in list(loop) if self.isRealNode( (x,f) )]
109        lpp = []
110        lp_phys = self.toPhyAddrs([(x,f) for x in lp])
111        for x in lp:
112          #loop entry, must be
113          #1. a basic block head and
114          #2. has >=1 edge from outside the loop
115          if (x, f ) in self.pf_deadends:
116              ##gotta be halt / branch to halt
117              continue
118          phy_n = self.phyAddr((x,f))
119          node = self.imm_nodes[phy_n]
120          imm_ext_edges_to = [y for y in node.edges_to if (y not in lp_phys)]
121          if ( len(imm_ext_edges_to) >= 1 and self.isBBHead((x,f)) ):
122            lpp.append(x)
123        return lpp
124
125    def findLoopheads(self):
126        self.imm_loopheads = {}
127        #loopheads = {}
128        loopheads = []
129        #self.loopheads = loopheads
130        loops_by_fs = self.loops_by_fs
131        for (f,p) in [(f,self.f_problems[f]) for f in self.f_problems]:
132            p.compute_preds()
133            p.do_loop_analysis()
134            l = p.loop_data
135            if p.loop_heads():
136                loops_by_fs[f] = {}
137            for x in p.loop_heads():
138              fun,_ = self.pNToFunGN((x,f))
139              #dodge halt
140              if fun in elfFile().deadend_funcs:
141                continue
142              loopheads.append((x, f))
143              #the 0 worker_id will get ignored by genLoopHeads.
144              #FIXME: do this properly..
145              loops_by_fs[f][x] = (2**30,'dummy',0)
146        assert loopheads
147        for p_nf in loopheads:
148          p_n, f = p_nf
149          p = self.f_problems[f]
150          ll = p.loop_data[p_n][1]
151          z = self.findLoopEntries(ll, f)
152          #map from potential heads -> head, hack around chronos 'feature'
153          for q in z:
154            assert q not in self.imm_loopheads, 'one addr cannot have >1 loopcounts !'
155            self.imm_loopheads[self.phyAddr((q,f))] = p_nf
156
157        return
158
    def firstRealNodes(self,p_nf,visited = None,may_multi=False,may_call=False,skip_ret=False):
        """
        Locate the first real node from, and including, p_nf,
            or branch targets if it hits a branch before that.
            Returns a list of p_nf

        Walks chains of non-real nodes (see isRealNode) starting at p_nf:
          - a real node ends the walk and is returned as [node];
          - reaching 'Ret' returns [('Ret', f)] when f, the function of the
            *starting* p_nf, is the root self.name; otherwise it returns []
            if skip_ret is set and fails an assertion if not;
          - a node with no successors (pNodeConts drops 'Err') yields [];
          - a node with several successors requires may_multi and recurses
            into each successor, concatenating the results.
        may_call is passed to pNodeConts, so with it set the walk follows a
        call node into the callee's entry.
        visited holds the start nodes of the enclosing recursive calls; a
        start node already in it yields []. Each branch gets its own copy.

        NOTE(review): the Ret test uses `f` from the starting p_nf, not
        `next_f`, so a walk that crosses into a callee (may_call) and reaches
        the callee's Ret is judged by the starting function -- confirm this
        is intended.
        """
        elf_fun = self.elf_fun
        p_n,f = p_nf
        next_p_nf = p_nf
        ret = []
        if visited == None:
            #print 'fRN on p_n %d, fun: %s' % (p_n,f)
            visited = []

        # only the start node of each (recursive) call is tracked
        if p_nf in visited:
          return []
        visited.append(p_nf)

        assert self.pf_deadends != None
        while True:
          if self.isRealNode(next_p_nf):
             return [next_p_nf]
          next_p_n , next_f, next_p = self.unpackPNF(next_p_nf)
          if ( next_p_n == 'Ret' and f == self.name):
            return [('Ret',f)]
          elif next_p_n == 'Ret':
            if skip_ret:
              return []
            assert False,'firstRealNodes reached Ret when skip_ret is False'
          p_node, edges = self.pNodeConts(next_p_nf, may_call=may_call)
          if edges == []:
            return []
          assert (edges)
          if len(edges) > 1:
            # branch: collect the first real nodes of every successor
            assert may_multi
            for p_e in edges:
                for ee in self.firstRealNodes(p_e ,visited = list(visited),may_multi=may_multi,may_call=may_call,skip_ret=skip_ret):
                  ret.append(ee)

            return ret
          else:
              # single successor: keep walking
              next_p_nf = edges[0]
201
    #(function name, graph node) that p_nf's node tag points back to
203    def pNToFunGN(self,p_nf):
204       p_n,f,p = self.unpackPNF(p_nf)
205       tag = p.node_tags[p_n]
206       _, x = tag
207       f_name, g_n = x
208       return f_name,g_n
209
    #given that p_nf is an imm call, return whether it is a tail call
211    def isCallTailCall(self,p_nf):
212        #    suc = p_n_cs[0]
213        g_n = self.phyAddr(p_nf)
214        return elf_parser.isDirectBranch(g_n)
215
216    def isStraightToRetToRoot(self,p_nf):
217        p_n,f,p = self.unpackPNF(p_nf)
218        if p_n == 'Ret' and f == self.name:
219          return True
220        elif p_n == 'Ret':
221          return False
222        if self.isRealNode(p_nf):
223          return False
224        if self.phyAddr(p_nf)=='RetToCaller':
225          return False
226        elif type(p_n) == int:
227          _,pf_conts = self.pNodeConts(p_nf)
228          p_conts = [x[0] for x in pf_conts]
229          if len(p_conts) == 1:
230            return self.isStraightToRetToRoot((p_conts[0],f))
231        return False
232
233
234    #whether the corresponding imm has a return edge
235    def isImmRootReturn(self,p_nf):
236        p_n,f = p_nf
237        if f != self.name :
238          return False
239        _, pf_conts = self.pNodeConts(p_nf)
240        for x in pf_conts:
241          if self.isStraightToRetToRoot(x):
242            return True
243        return False
244
    #whether p_nf leads straight to RetToCaller (or to a non-root Ret)
246    def isStraightToRetToCaller(self,p_nf):
247        p_n,f = p_nf
248        if p_n == 'Ret':
249          if f != self.name:
250            return True
251          else:
252            return False
253        if self.isRealNode(p_nf):
254          return False
255        if self.phyAddr(p_nf)=="RetToCaller":
256          return True
257        elif type(p_n) == int:
258          _,pf_conts = self.pNodeConts(p_nf)
259          p_conts = [x[0] for x in pf_conts]
260          if len(p_conts) == 1:
261            return self.isStraightToRetToCaller((p_conts[0],f))
262        return False
263
    #detects all returns except the root function's own return
265    def isImmRetToCaller(self,p_nf):
266        g_n = self.phyAddr(p_nf)
267        p_n,f,p = self.unpackPNF(p_nf)
268        if isCall(p.nodes[p_n]):
269          return False
270        p_node,pf_conts = self.pNodeConts(p_nf)
271        p_conts = [x[0] for x in pf_conts]
272
273        conts = [x for x in p_conts if type(p_n) == int]
274        #print '     p_n %s p_conts %s' % (p_n,p_conts)
275        n_rtc = 0
276        assert self.phyAddr(p_nf) == g_n
277        for pf_cont in pf_conts:
278          cont_n,cont_f = pf_cont
279          if not isCall(self.f_problems[cont_f].nodes[cont_n]):
280            if self.isStraightToRetToCaller(pf_cont):
281                ret = (pf_cont)
282                n_rtc += 1
283        if not ( n_rtc <= 1):
284          #print 'p_n %s g_n %s: n_rtc %s' % (p_n, self.phyAddr(p_n), n_rtc)
285          assert False
286        if n_rtc > 0:
287          return ret
288        return False
289
290    def funName(self,p_nf):
291        p_n,f = p_nf
292        fname = self.f_problems[f].nodes[p_n].fname
293        if '.' in fname:
294          #print 'f: %s' % fname
295          s = []
296          for c in fname:
297            if c == '.':
298              s.append('_')
299            else:
300              s.append(c)
301          return ''.join(s)
302        return fname
303
304    def makeProblem(self,f):
305        p = problem.Problem(None, 'Functions (%s)' % f)
306        p.add_entry_function(self.asm_fs[f], 'ASM')
307        p.do_analysis()
308        return p
309
310    def isSpecInsFunc(self,f):
311        """
312        Returns whether f is the name of  a special function
313        used to model special instruction
314        """
315        return f.startswith ("instruction'")
316
317    def makeBinGraph(self):
318        """
319        Prepare problems for all functions transitively called by self,
320        and turn this into a binary CFG
321        """
322        self.f_problems = {}
323        if self.name not in elfFile().tcg:
324            print elfFile().tcg.keys()
325        tc_fs = list(elfFile().tcg[self.name])
326        for f in tc_fs + [self.name]:
327            assert '.' not in f
328            if self.isSpecInsFunc(f):
329                continue
330            p = problem.Problem(None, 'Functions (%s)' % f)
331            p.add_entry_function(self.asm_fs[f], 'ASM')
332            self.f_problems[f] = p
333            #print 'f %s, p.nodes: %d' % (f,len(p.nodes) )
334            #get its entry
335            assert len(p.entries) == 1
336            self.p_entries[f] = p.entries[0][0]
337
338        print 'all problems generated'
339        self.findAllDeadends()
340        print "all deadends found"
341        #now generate the bin graph
342
343        for f,p in self.f_problems.iteritems():
344            for p_n in p.nodes:
345                if type(p_n) != int:
346                    continue
347                p_nf = (p_n,f)
348                if p_nf in self.pf_deadends:
349                    continue
350                if self.isRealNode(p_nf):
351                    #print 'adding: %s' % str(p_nf)
352                    self.addImmNode(p_nf)
353
354        self.imm_entry = self.phyAddr(self.firstRealNodes((self.p_entries[self.name], self.name ))[0])
355        #print 'self.imm_entry %x' % self.imm_entry
356        self.bbs = findBBs(self.imm_entry,self)
357
358
359    def findAllDeadends(self):
360        self.pf_deadends = []
361        pf_deadends = self.pf_deadends
362        self.deadend_g_ns = set()
363        #Halt is a deadend function, and should never be called, it's equivalent to Err for our purpose
364        for dead_f in elfFile().deadend_funcs:
365          print 'dead_f %s' % dead_f
366          deadend_f_g_n = elfFile().funcs[dead_f].addr
367          self.deadend_g_ns.add (deadend_f_g_n)
368          print 'deadend_f_g_n 0x%x' % deadend_f_g_n
369
370        for (f,p) in self.f_problems.iteritems():
371            for p_n in p.nodes:
372                if self.isDeadend((p_n,f)):
373                    pf_deadends.append((p_n,f))
374
375    def isDeadend(self,p_nf,visited=None):
376        '''
377        Determine if p_nf (p_n, function) is a deadend node
378        '''
379        if p_nf in self.pf_deadends:
380          return True
381        p_n, f, p = self.unpackPNF(p_nf)
382        if visited == None:
383          visited = []
384        if p_n == 'Err':
385          return True
386        if p_n == 'Ret':
387          return False
388        if p_nf in visited:
389          return True
390        if isCall(p.nodes[p_n]):
391            #walk into the callee problem
392            f = self.funName(p_nf)
393            #FIXME: dodge dummy functions
394            if 'instruction' in f:
395                return False
396            if f in elfFile().deadend_funcs:
397              return True
398            p_callee = self.f_problems[f]
399            assert len(p_callee.entries) == 1
400            p_callee_n = p_callee.entries[0][0]
401            return self.isDeadend((p_callee_n,f),visited=visited + [p_nf])
402
403        if type(p_n) == int and self.phyAddr(p_nf) == 'RetToCaller':
404          return False
405        g_n = self.phyAddr(p_nf)
406
407        if g_n in self.deadend_g_ns:
408          return True
409
410        #note: pNodeConts ensures we stay in the same problem
411        node,fconts = self.pNodeConts(p_nf)
412        conts = [ x[0] for x in fconts]
413        for p_c in conts:
414          assert p_c != p_n
415          if not self.isDeadend( (p_c,f), visited = visited + [p_nf]):
416            return False
417
418        #all ends are dead, thus deadend
419        return True
420
421    def unpackPNF(self,p_nf):
422        p_n,f = p_nf
423        p = self.f_problems[f]
424        return (p_n,f,p)
425
426    def phyAddr (self,p_nf) :
427        p_n, f , p = self.unpackPNF(p_nf)
428        if not isinstance(p_n,int):
429            return p_n
430        _,x = p.node_tags[p_n]
431        if x == 'LoopReturn':
432            return 'LoopReturn'
433        try:
434            f_name,g_addr = x
435        except:
436            print f
437            print 'tags: %s'%  str(p.node_tags[p_n])
438            assert False
439        return g_addr
440    #must not reach Ret
441    def pNodeConts(self, p_nf, no_deadends=False, may_call = False):
442        p_n,f, p = self.unpackPNF(p_nf)
443        p_node = p.nodes[p_n]
444        if isCall(p_node):
445          assert may_call
446          fun_called = self.funName(p_nf)
447          p = self.f_problems[fun_called]
448          entry = self.p_entries[fun_called]
449          pf_conts = [(entry,fun_called)]
450          return p_node, pf_conts
451        assert p_n != 'Ret'
452        p_conts = filter(lambda x: x != 'Err', p_node.get_conts())
453        if no_deadends:
454            p_conts = filter(lambda x: (x, p_i) not in pi_deadends, p_conts)
455        pf_conts = [(x , f) for x in p_conts]
456        return p_node,pf_conts
457
458    def isRealNode(self,p_nf):
459        p_n,f = p_nf
460        if p_n == 'Ret':
461          return False
462        g_n = self.phyAddr(p_nf)
463        if g_n == 'RetToCaller':
464            return False
465        elif self.isLoopReturn(p_nf):
466            return False
467        elif type(g_n) != int:
468            print 'g_n %s' % str(g_n)
469            assert False, 'g_n expected of typ int'
470        #elif g_n % 4 == 0 and not self.isLoopReturn(p_nf):
471        elif g_n % 4 == 0:
472            assert not self.isLoopReturn(p_nf)
473            return True
474        else:
475            return False
476
477    def isLoopReturn(self,p_nf):
478        p_n,f = p_nf
479        p = self.f_problems[f]
480        tag = p.node_tags[p_n]
481        return tag[1] == 'LoopReturn'
482
483    def addImmNode(self,p_nf):
484        imm_nodes = self.imm_nodes
485        g_n = self.phyAddr(p_nf)
486        p_node,pf_conts = self.pNodeConts(p_nf)
487        p_conts = [x[0] for x in pf_conts]
488        p_n,f,p = self.unpackPNF(p_nf)
489        #print "adding imm_node p_n: %s f: %s" % (p_n,f)
490        if g_n in imm_nodes:
491          #we have been here before
492          node = imm_nodes[g_n]
493        else:
494          node = immNode(g_n,rawVals(g_n))
495          imm_nodes[g_n] = node
496
497        dont_emit = []
498        p_imm_return_to_caller_edge = self.isImmRetToCaller(p_nf)
499        call_pn =  self.getCallTarg(p_nf)
500        if call_pn:
501            fun_called = self.funName((call_pn, f))
502            if self.isSpecInsFunc(fun_called):
503                #Hack: go straight to the return node, do nothing else
504                next_addrs = p.nodes[call_pn].get_conts()
505                assert len(next_addrs) == 1
506                next_addr = next_addrs[0]
507                assert next_addr not in ['Ret','Err']
508                phy_next_addr = self.phyAddr((next_addr,f))
509                i_e = immEdge(phy_next_addr, emit = True)
510                node.addEdge(i_e)
511                return
512            imm_call = self.parseImmCall(p_nf)
513            assert not p_imm_return_to_caller_edge
514            g_call_targ,g_ret_addr,is_tail_call = imm_call
515            dont_emit.append(g_call_targ)
516            node.addCallRetEdges(g_call_targ, g_ret_addr,is_tail_call)
517
518        elif p_imm_return_to_caller_edge or self.isImmRootReturn(p_nf):
519            node.addRetEdge()
520
521        #add edges to the imm node,ingore Err and halt
522        for p_targ in p_conts:
523          if type(p_targ) == int and (p_targ, f) not in self.pf_deadends:
524            if p_targ == 'Ret':
525              continue
526            edges = self.firstRealNodes((p_targ,f),may_multi=True,may_call=True,skip_ret=True)
527            for p_e in edges :
528              #dodge halt
529              if (p_e) in self.pf_deadends:
530                continue
531              g_e = self.phyAddr(p_e)
532              assert g_e != None
533              if g_e == 'Ret':
534                continue
535              assert g_e != 'Ret'
536              i_e = immEdge(g_e,emit = g_e not in dont_emit)
537              node.addEdge(i_e)
538
539    def retPF(self,call_p_nf):
540        p_n,f,p = self.unpackPNF(call_p_nf)
541        assert len(p.nodes[p_n].get_conts()) == 1
542        return ( (p.nodes[p_n].get_conts())[0] , f)
543
544    def getCallTarg(self, p_nf):
545        p_n,f,p = self.unpackPNF(p_nf)
546        _, pf_conts = self.pNodeConts(p_nf)
547        p_conts = map(lambda x: x[0],pf_conts)
548        #is Imm call iff there is a successor of kind Call in the g graph
549        p_n_cs = filter(lambda p_n_c:
550                        type(p_n_c) == int
551                        and not self.isLoopReturn(( p_n_c, f))
552                        and isCall(self.gNode((p_n_c,f)))
553                        , p_conts)
554        if not p_n_cs:
555          return None
556        assert len(p_n_cs) == 1
557        #return the p_n of the call node
558        return p_n_cs[0]
559
560    def parseImmCall(self,p_nf):
561        """
562        Returns (entry point to the called function, return addr, is_tailcall)
563        """
564        call_pn = self.getCallTarg(p_nf)
565        assert call_pn != None
566
567        p_n,f,p = self.unpackPNF(p_nf)
568        #print "p_n: %s, f: %s" % (p_n,f)
569        p_nodes = p.nodes
570        #find the return addr
571        #print "call_pn = %d" % call_pn
572
573        suc = self.firstRealNodes( (call_pn, f) ,may_multi=False,may_call=True)
574        pf_call_targ = suc[0]
575        g_call_targ = self.phyAddr(pf_call_targ)
576        #locate the call return address
577        f_caller, _ = self.pNToFunGN(p_nf)
578        is_tailcall = self.isCallTailCall(p_nf)
579        if not is_tailcall:
580            #return the return addr
581            phy_ret_addr = self.phyAddr(self.retPF((call_pn,f)))
582        else:
583            phy_ret_addr = None
584
585        assert type(phy_ret_addr) == int or is_tailcall, "g_call_targ %s phy_ret_addr %s" % (g_call_targ, phy_ret_addr)
586          #print 'call detected: phy_ret_addr %x' % phy_ret_addr
587        return (g_call_targ, phy_ret_addr,is_tailcall)
588
589    def gNode(self,p_nf):
590        p_n,f,p = self.unpackPNF(p_nf)
591        tag = p.node_tags[p_n]
592        f = tag[1][0]
593        g_n = tag[1][1]
594        return self.asm_fs[f].nodes[g_n]
595