1# * Copyright 2016, NICTA
2# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
6# * See "LICENSE_BSD2.txt" for details.
7# *
8# * @TAG(NICTA_BSD)
9
10#!/usr/bin/env python
11
12import os, re, sys, copy
13from subprocess import Popen, PIPE
14from elf_correlate import immFunc
15from elf_file import elfFile
16from addr_utils import callNodes,phyAddrP
17import bench
18import cplex
19import graph_refine.trace_refute as trace_refute
20import convert_loop_bounds
21from graph_refine.trace_refute import parse_num_arrow_list
22
23global bb_addr_to_ids
24global id_to_bb_addr
25global id_to_context
26global bb_count
27global edge_count
28global bb_addrs_in_loops
29global tcfg_paths
30bb_addr_to_ids = {}
31id_to_context = {}
32bb_count = {}
33edge_count = {}
34bb_addrs_in_loops = []
35tcfg_paths = {}
36id_to_bb_addr = {}
37
def cleanGlobalStates():
    """Rebind every module-level tcfg table to a fresh, empty container.

    We assume the ilp problem stays the same modulo extra refutes.
    """
    global bb_addr_to_ids, id_to_bb_addr, id_to_context
    global bb_count, edge_count, bb_addrs_in_loops, tcfg_paths

    bb_addr_to_ids, id_to_context = {}, {}
    bb_count, edge_count = {}, {}
    bb_addrs_in_loops = []
    tcfg_paths, id_to_bb_addr = {}, {}
    print('conflict.py: global states cleaned')
56
def read_variables(input_filename):
    """Read edge-count variables from a solution file into module globals.

    Lines of the form ``d<from>_<to> <number>`` (``<from>`` is a decimal id
    or ``Sta``) are parsed; every other line is ignored.  The number is
    rounded to the nearest integer.  Non-zero counts are stored in
    ``path_counts[from][to]`` (ids are kept as strings) and added to
    ``total_edges``.  A zero count still creates the ``path_counts[from]``
    entry but stores nothing under it.

    Raises:
        Exception: if a (from, to) pair with a stored count shows up again,
            or if ``path_counts`` is still empty after reading the file.
    """
    global path_counts
    global total_edges
    # Neither global is initialised anywhere in the visible part of this
    # module, so create them on first use; values that are already present
    # (from an earlier call or set by another module) are left untouched.
    module_state = globals()
    module_state.setdefault('path_counts', {})
    module_state.setdefault('total_edges', 0)

    var_re = re.compile(r'^d(\d+|Sta)_(\d+)\s+([\d.]+)$')

    # 'with' guarantees the file is closed even when a line raises.
    with open(input_filename) as f:
        for s in f:
            g = var_re.match(s.strip())
            if not g:
                continue

            from_id, to_id, count = g.groups()
            if from_id not in path_counts:
                path_counts[from_id] = {}

            if to_id in path_counts[from_id]:
                raise Exception("Variable from %s to %s mentioned more than once." % (from_id, to_id))

            count = int(round(float(count)))
            if count == 0:
                continue

            path_counts[from_id][to_id] = count
            total_edges += count

    if not path_counts:
        raise Exception("No variables found in solution.")
88
89
90
def read_tcfg_map(input_filename):
    """Parse a tcfg map file into the module-level lookup tables.

    A line is used when it looks like::

        <id>(<n>.<n>)(0x<addr>+0x<size>): [ <dest ids> ] <hex context words>

    and is skipped otherwise.  For each used line:

    * ``bb_addr_to_ids[addr]`` gains the tcfg id,
    * ``id_to_bb_addr[id]`` is set to the address (``id_to_bb_addr`` is
      reset at the start of the call; the other tables are only added to),
    * ``tcfg_paths[id]`` is set to the list of destination ids,
    * ``bb_count[id]`` and ``edge_count[(id, dest)]`` are set to 0,
    * the first context word is taken off; when it is non-zero the address
      is recorded in ``bb_addrs_in_loops``; the remaining words become
      ``id_to_context[id]``.

    Raises:
        AssertionError: if an address is below 0xe0000000 or a tcfg id
            occurs twice.
    """
    global bb_addr_to_ids
    global bb_count
    global edge_count
    global id_to_bb_addr

    tcfg_re = re.compile(
        r'''^
            (\d+)
            \( \d+ \. \d+ \)
            \(
                (0x[0-9a-f]+)
                \+
                (0x[0-9a-f]+)
            \):
            \s
            \[\s
                ([\d\s]*)
            \]([0-9a-f\s]+)$''',
        re.VERBOSE)

    # 'with' closes the file even if an assert or int() below raises.
    with open(input_filename) as f:
        id_to_bb_addr = {}
        for s in f:
            g = tcfg_re.match(s.strip())
            if not g:
                continue

            # the size group is matched but not needed here
            bb_id, bb_addr, _bb_size, bb_dests, ctx_str = g.groups()
            bb_id = int(bb_id)

            bb_addr = int(bb_addr, 16)
            #our context matching assumes that all bb addr has the same n_digits...
            assert bb_addr >= 0xe0000000, 'Go fix context matching :( bb_addr: %x' % bb_addr
            bb_dests = [int(x) for x in bb_dests.split()]
            ctx_list = [int(x, 16) for x in ctx_str.split(' ') if x != '']

            bb_addr_to_ids.setdefault(bb_addr, []).append(bb_id)

            # The regex guarantees at least one context word.  The first
            # one is the loop-head marker: non-zero means "inside a loop".
            loop_head = ctx_list[0]
            if loop_head != 0 and bb_addr not in bb_addrs_in_loops:
                bb_addrs_in_loops.append(bb_addr)
            ctx_list = ctx_list[1:]

            id_to_context[bb_id] = ctx_list
            tcfg_paths[bb_id] = bb_dests
            assert bb_id not in id_to_bb_addr
            id_to_bb_addr[bb_id] = bb_addr
            bb_count[bb_id] = 0
            for dest in bb_dests:
                edge_count[(bb_id, dest)] = 0
159
def callSitePA(pn,p):
    """Return phyAddrP() of the first element of p.preds[pn][0] ... i.e. of
    the first predecessor entry recorded for node pn in problem p."""
    first_pred = p.preds[pn][0]
    return phyAddrP(first_pred, p)
162
163def callstring_bbs(fs,cs_so_far = None):
164          bbAddr= immFunc().bbAddr
165          #print 'fs: %s' % str(fs)
166          #print 'cs_so_far: %s' % str(cs_so_far)
167          if not cs_so_far:
168                cs_so_far = [ ]
169          ret = []
170          top_f = fs[0]
171          next_f = fs[1]
172          #print 'top_f #%s# next_f #%s#' % (top_f,next_f)
173          p = immFunc().f_problems[top_f]
174          cns = callNodes(p,fs=[next_f])
175          #print 'cns:%s'% cns
176          pA = lambda x: phyAddrP(x,p)
177          #phy_rets =[phyAddrP(x,p) for x in cns]
178          phy_rets =[callSitePA(x,p) for x in cns]
179          #print ' phy_rets: %s' % str(phy_rets)
180          if len(fs) == 2:
181                return  [(cs_so_far +[bbAddr(x)]) for x in phy_rets]
182
183          assert len(fs) >2
184          for x in phy_rets:
185                ret += callstring_bbs(fs[1:],cs_so_far = cs_so_far + [bbAddr(x)])
186          return ret
187
188#does context match the list of bb_heads s ?
def bbs_context_match(context_targ, context_to_match):
    """Tell whether context_to_match, read back to front, is a prefix of the
    bb addresses of context_targ.

    The last entry of context_targ is dropped and the remaining entries are
    passed through bbAddr before comparing.
    """
    to_bb = immFunc().bbAddr
    targ_bbs = [to_bb(addr) for addr in context_targ[:-1]]
    wanted = len(context_to_match)
    if wanted > len(targ_bbs):
        return False
    # context_to_match[i] lines up with targ_bbs[wanted - 1 - i]
    return not any(context_to_match[i] != targ_bbs[wanted - 1 - i]
                   for i in range(wanted))
201
def inALoop(addr):
    """Return True when the bb address of addr is listed in
    bb_addrs_in_loops."""
    bb = bbAddr(addr)
    return bb in bb_addrs_in_loops
207
def inFunLoop(addr):
    """Return True when addr, translated with gToPAddrP() into the problem
    of the function containing it, is a key of that problem's loop_data."""
    from addr_utils import gToPAddrP
    func_name = elfFile().addrs_to_f[addr]
    problem = immFunc().f_problems[func_name]
    return gToPAddrP(addr, problem) in problem.loop_data
214
215#ids for bb_addr with context
def idsMatchingContext(full_bb_context):
    """Return one tcfg id per address of full_bb_context.

    For the address at position i the first id in bb_addr_to_ids[addr]
    whose stored context matches full_bb_context[:i] (per
    bbs_context_match) is taken; an AssertionError is raised when no id
    matches.
    """
    matched = []
    for pos, addr in enumerate(full_bb_context):
        prefix = full_bb_context[:pos]
        #we know there can only be 1 matching id, so the first hit is enough
        hit = next((tid for tid in bb_addr_to_ids[addr]
                    if bbs_context_match(id_to_context[tid], prefix)), None)
        assert hit != None
        matched.append(hit)
    return matched
229
230#translate context in bin to a list of ids corresponding to the context
231#note: no tcfg_ids can have the exact same context
def contextsBbAddrToIds(context,bb_addr):
    """Translate a context plus a final bb address into tcfg ids.

    context is a sequence whose items carry an address at index 0; each is
    passed through bbAddr.  Returns the result of idsMatchingContext() on
    those bb addresses followed by bb_addr.
    """
    bb_context = [bbAddr(x[0]) for x in context]
    # NOTE(review): the previous code called idsMatchingContext(context,
    # bb_addr), which always raised TypeError (it takes one argument) and
    # left bb_context unused.  Appending bb_addr to the translated context
    # is the presumed intent -- confirm against the callers.
    return idsMatchingContext(bb_context + [bb_addr])
236
def bbAddrs(l):
    """Map bbAddr over the truthy entries of l; falsy entries are dropped."""
    to_bb = immFunc().bbAddr
    result = []
    for addr in l:
        if addr:
            result.append(to_bb(addr))
    return result
240
#get the unique tid corresponding to addr whose context matches context[truncate:],
#given that context is in the form of the contexts stored in id_to_context
def addr_context_truncate_get_tid(addr, context, truncate):
    """Return the tcfg id at addr whose stored context equals
    context[truncate:] (or context itself when truncate is falsy).

    Exactly one id is expected; an AssertionError is raised when none
    matches.  When several ids match, a warning is printed and the first
    one is returned.
    """
    wanted = context[truncate:] if truncate else context
    candidates = []
    for tid in bb_addr_to_ids[addr]:
        if id_to_context[tid] == wanted:
            candidates.append(tid)

    if len(candidates) <= 1:
        assert len(candidates) == 1, (candidates, addr, wanted, truncate)
        return candidates[0]

    first = candidates[0]
    first_ctx = id_to_context[first]
    print('tids: %r' % candidates)
    assert all([id_to_context[t] == first_ctx for t in candidates])
    #FIXME: some pairs of tids seem to share the exact same bb_addr, full ctx.
    # seems to be related to tailcalls
    print("!! WARNING !! Figure out wht's going on !! Seems to be a chronos feature...")
    print('tids[0]: %d' % first)
    return first
262
263#emit a particular trace. full_context comes unmodified from id_to_context
def emitInconsistentTrace(fout, full_context,visits,line=None):
    """Write one ilp constraint for an inconsistent trace to fout.

    full_context is a context list as stored in id_to_context (possibly
    with one extra leading entry added by the caller).  visits is a
    sequence of (addr, stack_i) pairs.  Each visit is resolved to a tcfg id
    with addr_context_truncate_get_tid(); the constraint written is

        b<id_1> + ... + b<id_n> - (n-1) b<limiting id> <= 0

    where the limiting id belongs to full_context[max stack_i].
    """
    if line==None:
        line = ''
    print(' full_context: %r' % full_context)

    id_points = []
    for addr, stack_i in visits:
        print(' visit: %s, stack_i %d' % (addr,stack_i))
        #special case for addr == tip_context, since id_to_context will be
        #effectively one context short, as (addr / tip_context) itself will
        #be missing.
        at_tip = stack_i == 0 and bbAddr(addr) == bbAddr(full_context[0])
        drop = 1 if at_tip else stack_i
        id_points.append(addr_context_truncate_get_tid(addr, full_context, drop))

    max_stack_i = max(depth for _, depth in visits)
    print('max_stack_i: %d' % max_stack_i)
    assert (max_stack_i < len(full_context))

    limiting_node_bb = bbAddr(full_context[max_stack_i])
    limiting_tid = addr_context_truncate_get_tid(limiting_node_bb, full_context, max_stack_i + 1)

    #now emit the ilp constraint
    print('id_points: %r' % id_points)
    terms = ' + '.join(['b{0}'.format(p) for p in id_points])
    constraint = terms + ' - {0} b{1} <= 0\n'.format(len(id_points) - 1, limiting_tid)
    print('emitting \n%s\n' % constraint)
    fout.write(constraint)
293
294'''
295def isHalt(addr):
296  text = elfFile().lines[addr]
297  if 'halt' in text:
298    return True
299  return False
300'''
301#context is infeasible
def emitImpossible(fout,context):
    """Write 'b<tid> = 0' to fout for every tcfg id at the bb address of
    context[-1] whose stored context matches the bb addresses of
    context[:-1].  At least one id must match (asserted)."""
    lines = []
    for tid in bb_addr_to_ids[bbAddr(context[-1])]:
        if bbs_context_match(id_to_context[tid], bbAddrs(context[:-1])):
            lines.append('b{0} = 0\n'.format(tid))
    assert lines
    fout.write(''.join(lines))
    return
313
314#context addrs are bbAddrs
315#one_context is true iff all visits lie in context 0
def emitInconsistent(fout, context,visits):
    """Emit ilp constraints for an inconsistent (context, visits) pair.

    context is a list of addresses and visits a sequence of (addr, level)
    pairs.  With no visits the whole context is impossible and
    emitImpossible() is used.  Otherwise every tcfg id at the level-0 visit
    address whose stored context starts with the reversed context gets one
    constraint via emitInconsistentTrace().

    The caller's visits sequence is never modified.
    """
    if not visits:
        #this is just an impossible trace
        emitImpossible(fout,context)
        return
    # Work on a copy: the padding below must not leak into the caller's list.
    visits = list(visits)
    #FIXME:
    #bug : [4026621540, 4026621472] : [4026621428 <- 1] would become
    #      [4026621540] : [4026621428<-0], which isn't the same thing...
    #(that happened when levels were rebased to 0 and the context was
    # shortened accordingly, so that approach was abandoned)
    #padding visits with (context[-1],0) should work.
    #can't think of a better way atm
    if all(l > 0 for (_, l) in visits):
        visits.append( (bbAddr(context[-1]),0) )

    print('context: %r' % context)
    print('visits: %r' % visits)

    base_addr = [baddr for (baddr, l) in visits if l == 0][0]
    tids = bb_addr_to_ids[bbAddr(base_addr)]

    r_context = list(context)
    r_context.reverse()

    if bbAddr(base_addr) == bbAddr(context[-1]):
        #special case: the base visit is the tip of the context itself, so
        #the stored contexts lack that tip; compare against the rest and
        #put the tip back in front before emitting.
        for tid in tids:
            tid_ctxt = id_to_context[tid]
            if r_context and tid_ctxt[ : len(context) - 1] == r_context[1:]:
                emitInconsistentTrace(fout,[context[-1]]+tid_ctxt,visits)
    else:
        for tid in tids:
            tid_ctxt = id_to_context[tid]
            if r_context and tid_ctxt[ :len(context) ] == r_context:
                emitInconsistentTrace(fout,tid_ctxt,visits)

    return
357
def emit_f_conflicts (fout, line):
    """Emit 'b<tid> = 0' constraints for an infeasible function call chain.

    line has the form '[f1,f2,...]:<kind>'.  The function names are
    expanded with callstring_bbs() into call strings of bb addresses; for
    each call string every tcfg id at its last address whose stored context
    matches the preceding addresses is forced to 0.
    """
    bbAddr = immFunc().bbAddr
    match = re.search(r'\s*\[(?P<infeas>.*)\]\s*:\s*(?P<kind>.*$)', line)
    print('infeas: %s' % match.group('infeas'))
    infeasible_fs = match.group('infeas').split(',')
    print('infeasible_fs: %s' % infeasible_fs)
    #find all bb_addrs where f[-1] can be called by f[-2]
    bbCallStrings = callstring_bbs(infeasible_fs)
    print('bbcs: %s' % str(bbCallStrings))
    for s in bbCallStrings:
        # sanity check only: every entry must already be a bb address
        for x in s:
            assert bbAddr(x) == x
        if not s:
            continue
        # The block below used to sit inside the loop above, so each
        # constraint was written once per element of s.
        final_callee_bb = s[-1]
        assert final_callee_bb in bb_addr_to_ids
        ids = bb_addr_to_ids[final_callee_bb]
        ids_ctx = [tid for tid in ids if bbs_context_match(id_to_context[tid],s[:-1])]
        print('ids: %d, ids_ctx: %d' % (len(ids),len(ids_ctx)))
        #ok, now all of these are just unreachable.
        for tcfg_id in ids_ctx:
            fout.write("b{0} = 0\n".format(tcfg_id))
381
def process_conflict(fout, conflict_files):
    """Translate every conflict file into ilp constraints written to fout.

    Spaces are removed from each line; blank lines and lines starting with
    '#' are skipped.  The text after the last ':' selects the handling:

    * 'possible'             -- ignored
    * 'f_conflicts'          -- emit_f_conflicts()
    * 'phantom_preemp_point' -- '[<hex addr>]:...'; the address is collected
    * 'times'                -- '[<hex addr>]:[<n>]:...'; times_limit()
    * anything else          -- '<stack>:<visits>:<verdict>' where the verdict
      must contain 'impossible'; rejected (with a message) when any involved
      address is in a function loop, otherwise emitInconsistent()

    Returns the list of collected phantom preemption point addresses.

    Raises:
        ValueError: if a non-comment line contains no ':' at all.
    """
    fake_preemption_points = []
    for conflict_file in conflict_files:
        # 'with' closes the file even when a line fails to parse.
        with open(conflict_file,'r') as f:
            fout.write('\\ === conflict constraints from %s === \n\n' % conflict_file)
            bbAddr = immFunc().bbAddr
            for line in f:
                #get rid of all white spaces
                line = line.replace(' ', '').rstrip()
                if line == '' or line.startswith('#'):
                    #blank or comment line
                    continue
                match = re.search(r'.*:\s*(?P<kind>.*$)', line)
                if match is None:
                    raise ValueError('malformed conflict line in %s: %r' % (conflict_file, line))
                kind = match.group('kind')
                if kind == 'possible':
                    continue
                elif kind == 'f_conflicts':
                    emit_f_conflicts (fout, line)
                elif kind == 'phantom_preemp_point':
                    match = re.search(r'\[(?P<addr>.*)\]:(?P<kind>.*$)', line)
                    fake_preemption_points.append(int(match.group('addr'),16))
                elif kind == 'times':
                    match = re.search(r'\s*\[(?P<addr>.*)\]\s*:\s*\[(?P<times>.*)\]\s*:\s*(?P<kind>.*$)', line)
                    addr = int(match.group('addr'),16)
                    times = int(match.group('times'))
                    print('addr: %x' % addr)
                    print('times: %d' % times)
                    times_limit(fout,[addr],times)
                else:
                    [stack,visits,verdict] = line.split(':')
                    assert 'impossible' in verdict
                    stack = trace_refute.parse_num_list(stack)
                    visits = trace_refute.parse_num_arrow_list(visits)
                    bb_visits = [(bbAddr(x[0]),x[1]) for x in visits]
                    # stack[0] is deliberately not checked for loops
                    in_loops = [x for x in (stack[1:] + [x[0] for x in visits]) if inFunLoop(x)]
                    if in_loops:
                        print('!!! loops in inconsistent !!!')
                        print('%r' % in_loops)
                        print('rejected line: %s\n\n' % line)
                        continue
                    print('line: %s' % line)
                    emitInconsistent(fout, stack,bb_visits)
    fout.write("\n")
    return fake_preemption_points
437
def add_impossible_contexts(fout):
    """Write 'b<tid> = 0' for every tcfg id whose context has more than two
    entries and whose context, minus its last two entries, is rejected by
    trace_refute.ctxt_within_function_limits()."""
    fout.write('\\ === excluded function constraints === \n\n')
    for tid, ctxt in id_to_context.iteritems():
        too_short = len(ctxt) <= 2
        if not too_short and not trace_refute.ctxt_within_function_limits(ctxt[:-2]):
            fout.write('b%s = 0\n' % tid)
445
def print_constraints(conflict_files, old_cons_file, new_cons_file,sol_file_name, preempt_limit):
    """Build new_cons_file: a copy of old_cons_file with extra constraints.

    The old file is copied with 'cp', then the following is appended:
    cplex.endProblem(), an 'add' ... 'end' section holding the
    excluded-context, conflict and preemption constraints, and finally
    cplex.solveProblem().

    Raises:
        AssertionError: if the 'cp' command exits with a non-zero status.
    """
    #copy the file
    p = Popen(['cp', old_cons_file, new_cons_file])
    p.communicate()
    assert not p.returncode, 'cp %s %s failed' % (old_cons_file, new_cons_file)

    #append the new constraints at the end; 'with' makes sure the file is
    #closed even when generating the constraints raises
    with open(new_cons_file,'a+') as fout:
        cplex.endProblem(fout,sol_file_name)
        fout.write('add\n')
        add_impossible_contexts(fout)
        fake_preemption_points = process_conflict(fout,conflict_files)
        print('%r' % fake_preemption_points)
        preemption_limit(fout,fake_preemption_points,preempt_limit)
        fout.write('end\n')
        cplex.solveProblem(fout,presolve_off=False)
468
def id_print_context(id1,ctx=None):
    """Print elfFile().addrs_to_f[...] for every address of ctx, one per
    line.  Without ctx, the stored context of tcfg id id1 minus its last
    entry is used."""
    if ctx==None:
        stored = id_to_context[id1]
        ctx = stored[:-1]
    for ctx_addr in ctx:
        owner = elfFile().addrs_to_f[ctx_addr]
        print('%s'% owner)
    return
475
def preemption_limit(fout,fake_preemption_points,preempt_limit):
    """Write the preemption section header, then one times_limit()
    constraint over fake_preemption_points plus the address of the
    'preemptionPoint' function, bounded by preempt_limit."""
    fout.write('\\ === preemption constraints === \n\n')
    real_point = elfFile().funcs['preemptionPoint'].addr
    all_points = fake_preemption_points + [real_point]
    times_limit(fout, all_points, preempt_limit)
481
def times_limit(fout, addrs,limit):
    """Write 'b<id> + b<id> + ... <= limit' to fout, summing over every
    tcfg id registered in bb_addr_to_ids for the given addrs.

    Raises KeyError for an unknown address and IndexError when no id is
    found at all.
    """
    ids = []
    for addr in addrs:
        ids.extend(bb_addr_to_ids[addr])
    head = 'b{0} '.format(ids[0])
    tail = ''.join(' + b{0}'.format(tid) for tid in ids[1:])
    fout.write(head + tail + ' <= {0}\n'.format(limit))
490
def conflict(entry_point_function, tcfg_map, conflict_files, old_ilp, new_ilp, dir_name, sol_file, emit_conflicts=False, do_cplex=False, interactive=False, silent_cplex=False, preempt_limit= None, default_phantom_preempt=False):
    """Load the tcfg map and optionally emit (and solve) a refined ilp.

    Args:
        entry_point_function: passed on to bench.bench().
        tcfg_map: file read by read_tcfg_map().
        conflict_files: list of conflict files; it is not modified.
        old_ilp, new_ilp, sol_file: passed on to print_constraints() and
            cplex.cplexSolve().
        dir_name: target directory passed on to bench.bench().
        emit_conflicts: write new_ilp via print_constraints().
        do_cplex: with emit_conflicts, also run cplex.cplexSolve().
        interactive: stop with an AssertionError('Halt') once the tcfg map
            is loaded.
        silent_cplex: passed on to cplex.cplexSolve().
        preempt_limit: preemption bound; None means 5.
        default_phantom_preempt: also use the annotation file named by
            convert_loop_bounds.phantomPreemptsAnnoFileName(dir_name).

    Returns:
        The result of cplex.cplexSolve() when both emit_conflicts and
        do_cplex are set, otherwise None.
    """
    global bbAddr
    if preempt_limit is None:
        preempt_limit = 5
    if default_phantom_preempt:
        # Build a new list instead of appending: callers that pass the same
        # list to several calls would otherwise pile up duplicate entries.
        conflict_files = list(conflict_files) + [convert_loop_bounds.phantomPreemptsAnnoFileName(dir_name)]
    #initialise graph_to_graph so we get immFunc
    #load the loop_counts
    print('conflict.conflict: sol_file %s' % sol_file)
    bench.bench(dir_name, entry_point_function,False,True,False,parse_only=True )
    #we need the loop data
    immFunc().process()
    # published as a module global: several helpers in this file call bbAddr
    bbAddr = immFunc().bbAddr
    read_tcfg_map(tcfg_map)
    if interactive:
        assert False, 'Halt'
    if emit_conflicts:
        print('new_ilp:%s' % new_ilp)
        print_constraints(conflict_files, old_ilp, new_ilp, sol_file, preempt_limit)
        if do_cplex:
            cplex_ret = cplex.cplexSolve(new_ilp,silent=silent_cplex,sol_file=sol_file)
            print('cplex_ret: %s' % cplex_ret)
            return cplex_ret
514        #print_constraints(sys.argv[3], sys.argv[4], sys.argv[5])
515
# Command-line entry point: check the argument count, unpack the eight
# positional arguments and run conflict() for 'handleSyscall'.
if __name__ == '__main__':
    if len(sys.argv) != 9:
        print('''Usage: python conflict.py [tcfg map] [conflict file] [ilp file with footer stripped] [new ilp file] [target_dir] [flag] [preemption limit] [sol file to be generated]
                conflict file 1 and/or 2 can be empty
                flag is one of:
                        --c
                                generate a new conflict file
                        --i
                                interactive mode, for debugging only
                        --cx
                                generate a new conflict file and call cplex directly''')
        sys.exit(1)

    tcfg_map, conflict_arg, old_ilp, new_ilp, dir_name, flag = sys.argv[1:7]
    preempt_limit = int(sys.argv[7])
    print('preempt_limit: %d' % preempt_limit)
    sol_file = sys.argv[8]
    assert 'sol' in sol_file
    conflict_files = [conflict_arg]
    print('sol_file: %s' % sol_file)
    print('old_ilp %s' % old_ilp)
    print('conflict_files: %r' % conflict_files)

    # an unrecognised flag leaves all three switches off
    emit_conflicts = flag in ('--c', '--cx')
    do_cplex = flag == '--cx'
    interactive = flag == '--i'

    #FIXME: un-hardcode the entry point function's name
    ret = conflict('handleSyscall', tcfg_map, conflict_files, old_ilp, new_ilp,
                   dir_name, sol_file,
                   emit_conflicts=emit_conflicts,
                   do_cplex=do_cplex,
                   interactive=interactive,
                   preempt_limit=preempt_limit,
                   default_phantom_preempt=True)
    print('conflict terminated')
    print('ret: %s' % str(ret))
556
557