1#!/usr/bin/env python
2#
3#
4# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
5#
6# SPDX-License-Identifier: BSD-2-Clause
7#
8
9import os, re, sys, copy
10from subprocess import Popen, PIPE
11from elf_correlate import immFunc
12from elf_file import elfFile
13from addr_utils import callNodes,phyAddrP
14import bench
15import cplex
16import graph_refine.trace_refute as trace_refute
17import convert_loop_bounds
18from graph_refine.trace_refute import parse_num_arrow_list
19
# Module-level TCFG tables. They are filled in by read_tcfg_map() and reset
# by cleanGlobalStates(). (A `global` statement at module scope has no
# effect, so the names are simply bound here.)

# basic-block address -> list of tcfg ids sharing that address
bb_addr_to_ids = dict()
# tcfg id -> context list (loop-head marker already stripped)
id_to_context = dict()
# tcfg id -> counter, initialised to 0 by read_tcfg_map()
bb_count = dict()
# (tcfg id, destination tcfg id) -> counter, initialised to 0
edge_count = dict()
# basic-block addresses whose tcfg line carried a non-zero loop head
bb_addrs_in_loops = list()
# tcfg id -> list of destination tcfg ids
tcfg_paths = dict()
# tcfg id -> basic-block address
id_to_bb_addr = dict()
34
def cleanGlobalStates():
    """Rebind every module-level TCFG table to a fresh empty container.

    Prints a one-line notice once the tables have been reset.
    """
    global bb_addr_to_ids, id_to_bb_addr, id_to_context
    global bb_count, edge_count
    global bb_addrs_in_loops, tcfg_paths

    # we assume the ilp problem stays the same modulo extra refutes
    bb_addr_to_ids, id_to_context, id_to_bb_addr = {}, {}, {}
    bb_count, edge_count, tcfg_paths = {}, {}, {}
    bb_addrs_in_loops = []
    print('conflict.py: global states cleaned')
53
def read_variables(input_filename):
    """Accumulate edge counts from a solution file into module globals.

    Every line of the form ``d<from>_<to> <count>`` (``<from>`` is a number
    or ``Sta``) is parsed; the count is rounded to the nearest integer and,
    when non-zero, stored in ``path_counts[from][to]`` and added to
    ``total_edges``. Lines that do not match are ignored.

    Raises:
        Exception: if a (from, to) pair with a non-zero count is seen again,
            or if no variable line was found at all.
    """
    global path_counts
    global total_edges

    # Neither global is initialised anywhere at module level, so create them
    # on first use instead of failing with a NameError. Existing values
    # (e.g. set by a caller) are kept and accumulated into.
    module_state = globals()
    module_state.setdefault('path_counts', {})
    module_state.setdefault('total_edges', 0)

    var_re = re.compile(r'^d(\d+|Sta)_(\d+)\s+([\d.]+)$')

    # `with` guarantees the file is closed even when a line raises.
    with open(input_filename) as f:
        for s in f:
            g = var_re.match(s.strip())
            if not g:
                continue

            from_id, to_id, count = g.groups()
            if from_id not in path_counts:
                path_counts[from_id] = {}

            if to_id in path_counts[from_id]:
                raise Exception("Variable from %s to %s mentioned more than once." % (from_id, to_id))

            count = int(round(float(count)))
            if count == 0:
                continue

            path_counts[from_id][to_id] = count
            total_edges += count

    if not path_counts:
        raise Exception("No variables found in solution.")
85
86
87
def read_tcfg_map(input_filename):
    """Parse a TCFG map file and fill the module-level lookup tables.

    Each matching line has the shape
    ``<id>(<n>.<n>)(0x<addr>+0x<size>): [ <dest ids> ] <hex ctx words>``.
    The first context word is the loop head (0 means "not in a loop"); the
    remaining words are stored as the id's context. Non-matching lines are
    skipped.

    Side effects: rebinds ``id_to_bb_addr`` to a fresh dict and adds entries
    to ``bb_addr_to_ids``, ``bb_addrs_in_loops``, ``id_to_context``,
    ``tcfg_paths``, ``bb_count`` and ``edge_count``.

    Raises:
        AssertionError: if a block address is below 0xe0000000 or a tcfg id
            occurs twice.
    """
    global bb_addr_to_ids
    global bb_count
    global edge_count
    global id_to_bb_addr

    tcfg_re = re.compile(
                r'''^
                        (\d+)
                        \( \d+ \. \d+ \)
                        \(
                                (0x[0-9a-f]+)
                                \+
                                (0x[0-9a-f]+)
                        \):
                        \s
                        \[\s
                                ([\d\s]*)
                        \]([0-9a-f\s]+)$''',
                        re.VERBOSE)

    # `with` closes the file even if an assertion below fires mid-parse.
    with open(input_filename) as f:
        id_to_bb_addr = {}
        for s in f:
            g = tcfg_re.match(s.strip())
            if not g:
                continue

            # the size field is matched by the pattern but not used
            bb_id, bb_addr, _bb_size, bb_dests, ctx_str = g.groups()
            bb_id = int(bb_id)

            bb_addr = int(bb_addr, 16)
            #our context matching assumes that all bb addr has the same n_digits...
            assert bb_addr >= 0xe0000000, 'Go fix context matching :( bb_addr: %x' % bb_addr
            bb_dests = [int(x) for x in bb_dests.split()]
            ctx_list = [int(x, 16) for x in ctx_str.split(' ') if x != '']

            bb_addr_to_ids.setdefault(bb_addr, []).append(bb_id)

            # a non-zero leading word marks the block as being inside a loop
            loop_head = ctx_list[0]
            if loop_head != 0 and bb_addr not in bb_addrs_in_loops:
                bb_addrs_in_loops.append(bb_addr)
            ctx_list = ctx_list[1:]

            id_to_context[bb_id] = ctx_list
            tcfg_paths[bb_id] = bb_dests
            assert bb_id not in id_to_bb_addr
            id_to_bb_addr[bb_id] = bb_addr
            bb_count[bb_id] = 0
            for dest in bb_dests:
                edge_count[(bb_id, dest)] = 0
156
def callSitePA(pn,p):
    """Return phyAddrP() of the first entry in p.preds[pn], evaluated in p."""
    first_pred = p.preds[pn][0]
    return phyAddrP(first_pred, p)
159
def callstring_bbs(fs, cs_so_far = None):
    """Expand a chain of function names into lists of call-site bb addresses.

    For fs = [f0, f1, ...] the call nodes in f0's problem that target f1 are
    looked up, each is mapped through callSitePA() and bbAddr(), and the
    result is appended to cs_so_far. With more than two names the expansion
    recurses on fs[1:], so one list is returned per combination of call
    sites along the chain. fs must hold at least two names.
    """
    bbAddr = immFunc().bbAddr
    prefix = cs_so_far if cs_so_far else []
    caller, callee = fs[0], fs[1]
    problem = immFunc().f_problems[caller]
    call_nodes = callNodes(problem, fs=[callee])
    phy_rets = [callSitePA(node, problem) for node in call_nodes]

    if len(fs) == 2:
        # last link of the chain: one call string per call site
        return [prefix + [bbAddr(site)] for site in phy_rets]

    assert len(fs) > 2
    collected = []
    for site in phy_rets:
        collected.extend(callstring_bbs(fs[1:], cs_so_far = prefix + [bbAddr(site)]))
    return collected
184
185#does context match the list of bb_heads s ?
def bbs_context_match(context_targ, context_to_match):
    """Check whether context_targ starts with context_to_match reversed.

    The last element of context_targ is dropped and the rest is mapped
    through bbAddr(). The first len(context_to_match) mapped entries, read
    backwards, must equal context_to_match element by element.
    """
    bbAddr = immFunc().bbAddr
    # all of context_to_match must be in the context; strip the trailing entry
    bb_ctx = [bbAddr(a) for a in context_targ[:-1]]
    wanted = len(context_to_match)
    if wanted > len(bb_ctx):
        return False
    for expected, actual in zip(context_to_match, reversed(bb_ctx[:wanted])):
        if expected != actual:
            return False
    return True
198
def inALoop(addr):
    '''
    Report whether bbAddr(addr) is recorded in bb_addrs_in_loops.
    '''
    block_addr = bbAddr(addr)
    return block_addr in bb_addrs_in_loops
204
def inFunLoop(addr):
    """Report whether addr, translated by gToPAddrP, is a key of loop_data.

    The function containing addr is looked up via elfFile().addrs_to_f and
    its problem via immFunc().f_problems.
    """
    from addr_utils import gToPAddrP
    fun_name = elfFile().addrs_to_f[addr]
    problem = immFunc().f_problems[fun_name]
    return gToPAddrP(addr, problem) in problem.loop_data
211
212#ids for bb_addr with context
def idsMatchingContext(full_bb_context):
    """Map each address in full_bb_context to the first tcfg id matching it.

    For the address at position i, the candidates are bb_addr_to_ids[addr];
    the first one whose stored context matches full_bb_context[:i] (per
    bbs_context_match) is chosen. Asserts that a match exists for every
    position.
    """
    matched = []
    for depth, addr in enumerate(full_bb_context):
        outer = full_bb_context[:depth]
        #we know there can only be 1 matching id
        match_id = next(
            (cand for cand in bb_addr_to_ids[addr]
             if bbs_context_match(id_to_context[cand], outer)),
            None)
        assert match_id is not None
        matched.append(match_id)
    return matched
226
227#translate context in bin to a list of ids corresponding to the context
228#note: no tcfg_ids can have the exact same context
def contextsBbAddrToIds(context,bb_addr):
    """Translate a context plus a final bb address into tcfg ids.

    The first field of each context entry is mapped through bbAddr(), bb_addr
    is appended, and the resulting list is resolved by idsMatchingContext().

    NOTE(review): the original passed (context, bb_addr) to the one-parameter
    idsMatchingContext(), which always raised TypeError, and left the
    translated bb_context unused. Appending bb_addr to the translated context
    is the presumed intent - confirm against callers.
    """
    bb_context = [bbAddr(x[0]) for x in context]
    return idsMatchingContext(bb_context + [bb_addr])
233
def bbAddrs(l):
    """Map every truthy element of l through immFunc().bbAddr, in order."""
    to_bb = immFunc().bbAddr
    converted = []
    for addr in l:
        if addr:
            converted.append(to_bb(addr))
    return converted
237
238#get the unique tid corresponding to addr that matches context[: -truncate],
239#given that context is in the from of contexts from id_to_context
def addr_context_truncate_get_tid(addr, context, truncate):
    """Return the tcfg id of addr whose stored context equals context[truncate:].

    context is expected in the form stored in id_to_context. When several
    ids share the exact same context a warning is printed and the first one
    is returned; when none match, the assertion fails.
    """
    wanted = context[truncate:] if truncate else context
    tids = []
    for tid in bb_addr_to_ids[addr]:
        if id_to_context[tid] == wanted:
            tids.append(tid)

    if len(tids) <= 1:
        assert len(tids) == 1, (tids, addr, wanted, truncate)
        [tid] = tids
        return tid

    ctx = id_to_context[tids[0]]
    print('tids: %r' % tids)
    assert all([id_to_context[x] == ctx for x in tids])
    #FIXME: some pairs of tids seem to share the exact same bb_addr, full ctx.
    # seems to be related to tailcalls
    print('!! WARNING !! Figure out wht\'s going on !! Seems to be a chronos feature...')
    print('tids[0]: %d' % tids[0])
    return tids[0]
259
260#emit a particular trace. full_context comes unmodified from id_to_context
def emitInconsistentTrace(fout, full_context,visits,line=None):
    """Write one ILP constraint ruling out a set of visits in full_context.

    Each (addr, stack_i) visit is resolved to a tcfg id by dropping stack_i
    leading entries of full_context (one entry when the visit is at depth 0
    on the same block as full_context[0]). The emitted line is
    ``b<p1> + ... + b<pn> - (n-1) b<limit> <= 0`` where <limit> is the tcfg
    id of full_context[max depth] with max depth + 1 entries dropped.
    full_context comes unmodified from id_to_context. line is accepted but
    not used in the output.
    """
    if line is None:
        line = ''
    print(' full_context: %r' % full_context)

    id_points = []
    for addr, stack_i in visits:
        print(' visit: %s, stack_i %d' % (addr,stack_i))
        # special case for addr == tip_context: id_to_context is effectively
        # one context short there, as the tip itself is missing from it
        at_tip = stack_i == 0 and bbAddr(addr) == bbAddr(full_context[0])
        drop = 1 if at_tip else stack_i
        id_points.append(addr_context_truncate_get_tid(addr, full_context, drop))

    max_stack_i = max(depth for (_, depth) in visits)
    print('max_stack_i: %d' % max_stack_i)
    assert max_stack_i < len(full_context)

    limiting_node_bb = bbAddr(full_context[max_stack_i])
    limiting_tid = addr_context_truncate_get_tid(limiting_node_bb, full_context, max_stack_i + 1)

    #now emit the ilp constraint
    print('id_points: %r' % id_points)
    terms = ' + '.join('b{0}'.format(p) for p in id_points)
    s = terms + ' - {0} b{1} <= 0\n'.format(len(id_points) - 1, limiting_tid)
    print('emitting \n%s\n' % s)
    fout.write(s)
290
291'''
292def isHalt(addr):
293  text = elfFile().lines[addr]
294  if 'halt' in text:
295    return True
296  return False
297'''
298#context is infeasible
def emitImpossible(fout,context):
    """Write ``b<tid> = 0`` for every tcfg id matching an infeasible context.

    Candidates are the ids of bbAddr(context[-1]); those whose stored context
    matches bbAddrs(context[:-1]) are kept. Asserts at least one id matches.
    """
    #get the matching ids
    candidates = bb_addr_to_ids[bbAddr(context[-1])]
    tids = []
    for tid in candidates:
        if bbs_context_match(id_to_context[tid], bbAddrs(context[:-1])):
            tids.append(tid)
    assert tids
    fout.write(''.join('b{0} = 0\n'.format(tid) for tid in tids))
    return
310
311#context addrs are bbAddrs
312#one_context is true iff all visits lie in context 0
def emitInconsistent(fout, context,visits):
    """Emit ILP constraints for an inconsistent (context, visits) pair.

    An empty visits list means the context itself is impossible and is
    delegated to emitImpossible(). Otherwise the depth-0 visit selects the
    candidate tcfg ids, and for every candidate whose stored context starts
    with the reversed context, emitInconsistentTrace() is called.

    visits is a list of (bb address, stack depth) pairs. The caller's list
    is no longer modified: the padding entry is added to a copy.
    """
    if visits == []:
        #this is just an impossible trace
        emitImpossible(fout,context)
        return
    #FIXME:
    #bug : [4026621540, 4026621472] : [4026621428 <- 1] would become
    #      [4026621540] : [4026621428<-0], which isn't the same thing...
    # (that is why depths are not rebased by the minimum depth here)

    #padding visits with (context[-1],0) should work.
    #can't think of a better way atm
    if all([l > 0 for (_, l) in visits]):
        # build a new list rather than appending to the caller's argument
        visits = list(visits) + [(bbAddr(context[-1]), 0)]

    print('context: %r' % context)
    print('visits: %r' % visits)

    base_addr = [baddr for (baddr, l) in visits if l == 0][0]
    tids = bb_addr_to_ids[bbAddr(base_addr)]

    r_context = list(context)
    r_context.reverse()

    if bbAddr(base_addr) == bbAddr(context[-1]):
        #special case: the base block is the tip of the context, so the
        #stored contexts lack the tip itself
        for tid in tids:
            tid_ctxt = id_to_context[tid]
            if r_context and tid_ctxt[ : len(context) - 1] == r_context[1:]:
                emitInconsistentTrace(fout,[context[-1]]+tid_ctxt,visits)
    else:
        for tid in tids:
            tid_ctxt = id_to_context[tid]
            if r_context and tid_ctxt[ :len(context) ] == r_context:
                emitInconsistentTrace(fout,tid_ctxt,visits)

    return
354
def emit_f_conflicts (fout, line):
    '''
    Write ``b<id> = 0`` for every tcfg id reached through an infeasible
    function call chain.

    line has the shape ``[f0,f1,...]:<kind>``. The function names are
    expanded by callstring_bbs() into call strings of bb addresses; for each
    call string, the ids of its last block whose stored context matches the
    preceding blocks are marked unreachable.

    The per-call-string work used to sit inside the per-element sanity-check
    loop, so every constraint was written once per element of the call
    string. It now runs once per call string.
    '''
    bbAddr = immFunc().bbAddr
    match = re.search(r'\s*\[(?P<infeas>.*)\]\s*:\s*(?P<kind>.*$)', line)
    print('infeas: %s' % match.group('infeas'))
    infeasible_fs = match.group('infeas').split(',')
    print('infeasible_fs: %s' % infeasible_fs)
    #find all bb_addrs where f[-1] can be called by f[-2]
    bbCallStrings = callstring_bbs(infeasible_fs)
    print('bbcs: %s' % str(bbCallStrings))
    for s in bbCallStrings:
        if not s:
            # nothing to constrain for an empty call string
            continue
        # every entry must already be a basic-block head address
        for x in s:
            assert bbAddr(x) == x
        final_callee_bb = s[-1]
        assert final_callee_bb in bb_addr_to_ids
        ids = bb_addr_to_ids[final_callee_bb]
        ids_ctx = [tid for tid in ids if bbs_context_match(id_to_context[tid],s[:-1])]
        print('ids: %d, ids_ctx: %d' % (len(ids),len(ids_ctx)))
        #ok, now all of these are just unreachable.
        for tcfg_id in ids_ctx:
            fout.write("b{0} = 0\n".format(tcfg_id))
378
def process_conflict(fout, conflict_files):
    """Translate conflict annotation files into ILP constraints on fout.

    Spaces are removed from every line; blank lines and lines starting with
    '#' are skipped. The text after the last ':' selects the handling:

    * ``possible``             - ignored
    * ``f_conflicts``          - passed to emit_f_conflicts()
    * ``phantom_preemp_point`` - the hex address in [...] is collected
    * ``times``                - ``[hex addr]:[n]:times``, passed to times_limit()
    * anything else            - ``stack:visits:verdict`` with 'impossible'
      in the verdict; passed to emitInconsistent() unless an address
      involved is inside a function loop, in which case it is rejected.

    Returns the list of collected phantom preemption point addresses.
    """
    fake_preemption_points = []
    for conflict_file in conflict_files:
        # `with` closes the conflict file even when a line fails to parse
        with open(conflict_file,'r') as f:
            fout.write('\\ === conflict constraints from %s === \n\n' % conflict_file)
            bbAddr = immFunc().bbAddr
            for line in f:
                #get rid of all white spaces
                line = line.replace(' ', '').rstrip()
                if line == '' or line.startswith('#'):
                    #blank or comment line
                    continue
                match = re.search(r'.*:\s*(?P<kind>.*$)', line)
                kind = match.group('kind')
                if kind == 'possible':
                    continue
                elif kind == 'f_conflicts':
                    emit_f_conflicts (fout, line)
                elif kind == 'phantom_preemp_point':
                    match = re.search(r'\[(?P<addr>.*)\]:(?P<kind>.*$)', line)
                    fake_preemption_points.append(int(match.group('addr'),16))
                elif kind == 'times':
                    match = re.search(r'\s*\[(?P<addr>.*)\]\s*:\s*\[(?P<times>.*)\]\s*:\s*(?P<kind>.*$)', line)
                    addr = int(match.group('addr'),16)
                    times = int(match.group('times'))
                    print('addr: %x' % addr)
                    print('times: %d' % times)
                    times_limit(fout,[addr],times)
                else:
                    bits = line.split(':')
                    [stack,visits,verdict] = bits
                    assert 'impossible' in verdict
                    stack = trace_refute.parse_num_list(stack)
                    visits = trace_refute.parse_num_arrow_list(visits)
                    bb_visits = [(bbAddr(x[0]),x[1]) for x in visits]
                    # the first stack entry is not checked for loop membership
                    in_loops = [x for x in (stack[1:] + [x[0] for x in visits]) if inFunLoop(x)]
                    if in_loops:
                        print('!!! loops in inconsistent !!!')
                        print('%r' % in_loops)
                        print('rejected line: %s\n\n' % line)
                        continue
                    print('line: %s' % line)
                    emitInconsistent(fout, stack,bb_visits)
    fout.write("\n")
    return fake_preemption_points
434
def add_impossible_contexts(fout):
    """Write ``b<tid> = 0`` for contexts rejected by the function limits.

    Only contexts longer than two entries are considered; the last two
    entries are dropped before asking
    trace_refute.ctxt_within_function_limits().
    """
    fout.write('\\ === excluded function constraints === \n\n')
    for tid, ctxt in id_to_context.items():
        if len(ctxt) > 2 and not trace_refute.ctxt_within_function_limits(ctxt[:-2]):
            fout.write('b%s = 0\n' % tid)
442
def print_constraints(conflict_files, old_cons_file, new_cons_file,sol_file_name, preempt_limit):
    """Copy old_cons_file to new_cons_file and append the extra constraints.

    After the copy (done with ``cp``), the file is opened for appending and
    receives, in order: cplex.endProblem(), an ``add`` line, the excluded
    context constraints, the conflict constraints, the preemption limit, an
    ``end`` line and cplex.solveProblem().

    Raises:
        AssertionError: if ``cp`` exits with a non-zero status.
    """
    #copy the file
    p = Popen(['cp', old_cons_file, new_cons_file])
    p.communicate()
    assert not p.returncode
    #append the new constraints at the end; `with` closes the file even if
    #one of the emitters raises
    with open(new_cons_file,'a+') as fout:
        cplex.endProblem(fout,sol_file_name)
        fout.write('add\n')
        add_impossible_contexts(fout)
        fake_preemption_points = process_conflict(fout,conflict_files)
        print('%r' % fake_preemption_points)
        preemption_limit(fout,fake_preemption_points,preempt_limit)
        fout.write('end\n')
        cplex.solveProblem(fout,presolve_off=False)
465
def id_print_context(id1,ctx=None):
    """Print elfFile().addrs_to_f[bb] for every bb of a context.

    When ctx is not given, id_to_context[id1] without its last entry is used.
    """
    if ctx is None:
        ctx = id_to_context[id1][:-1]
    for bb in ctx:
        print('%s' % elfFile().addrs_to_f[bb])
    return
472
def preemption_limit(fout,fake_preemption_points,preempt_limit):
    """Write the preemption section header and one times_limit() constraint.

    The constraint covers fake_preemption_points plus the address of the
    'preemptionPoint' function taken from elfFile().funcs.
    """
    fout.write('\\ === preemption constraints === \n\n')
    preemp_addr = elfFile().funcs['preemptionPoint'].addr
    limited_addrs = fake_preemption_points + [preemp_addr]
    times_limit(fout, limited_addrs, preempt_limit)
478
def times_limit(fout, addrs,limit):
    """Write ``b<id1>  + b<id2> ... <= limit`` over all tcfg ids of addrs.

    The ids are gathered from bb_addr_to_ids for each address, in order.
    When no id is collected, nothing is written (the original raised
    IndexError on the empty list). An address missing from bb_addr_to_ids
    still raises KeyError.
    """
    ids = []
    for addr in addrs:
        ids.extend(bb_addr_to_ids[addr])
    if not ids:
        # an empty sum cannot be expressed as a constraint line; skip it
        return
    # the first term keeps its trailing space so the output is unchanged
    first_term = 'b{0} '.format(ids[0])
    other_terms = ''.join(' + b{0}'.format(x) for x in ids[1:])
    fout.write(first_term + other_terms + ' <= {0}\n'.format(limit))
487
def conflict(entry_point_function, tcfg_map, conflict_files, old_ilp, new_ilp, dir_name, sol_file, emit_conflicts=False, do_cplex=False, interactive=False, silent_cplex=False, preempt_limit= None, default_phantom_preempt=False):
    """Load the TCFG map and optionally emit (and solve) a refined ILP.

    Steps: run bench.bench() in parse-only mode for dir_name, process
    immFunc(), publish immFunc().bbAddr as the module-global bbAddr, and
    read tcfg_map. With emit_conflicts, print_constraints() writes new_ilp
    from old_ilp plus the conflict files; with do_cplex as well,
    cplex.cplexSolve() is run on it and its result returned.

    preempt_limit defaults to 5 when None. With default_phantom_preempt the
    phantom-preemption annotation file of dir_name is added to the conflict
    files; this is done on a copy, so the caller's list is left untouched.
    interactive aborts with an assertion after loading.

    Returns the cplexSolve() result when both emit_conflicts and do_cplex are
    set, otherwise None.
    """
    if preempt_limit is None:
        preempt_limit = 5
    # work on a copy so the caller's list is not modified by the append below
    conflict_files = list(conflict_files)
    if default_phantom_preempt:
        conflict_files.append(convert_loop_bounds.phantomPreemptsAnnoFileName(dir_name))
    #initialise graph_to_graph so we get immFunc
    #load the loop_counts
    print('conflict.conflict: sol_file %s' % sol_file)
    bench.bench(dir_name, entry_point_function,False,True,False,parse_only=True )
    #we need the loop data
    immFunc().process()
    global bbAddr
    bbAddr = immFunc().bbAddr
    read_tcfg_map(tcfg_map)
    if interactive:
        assert False, 'Halt'
    if emit_conflicts:
        print('new_ilp:%s' % new_ilp)
        print_constraints(conflict_files, old_ilp, new_ilp, sol_file, preempt_limit)
        if do_cplex:
            cplex_ret = cplex.cplexSolve(new_ilp,silent=silent_cplex,sol_file=sol_file)
            print('cplex_ret: %s' % cplex_ret)
            return cplex_ret
    return None
511        #print_constraints(sys.argv[3], sys.argv[4], sys.argv[5])
512
if __name__ == '__main__':
    # Command-line entry point: expects exactly eight arguments.
    if len(sys.argv) != 9:
        print('''Usage: python conflict.py [tcfg map] [conflict file] [ilp file with footer stripped] [new ilp file] [target_dir] [flag] [preemption limit] [sol file to be generated]
                conflict file 1 and/or 2 can be empty
                flag is one of:
                        --c
                                generate a new conflict file
                        --i
                                interactive mode, for debugging only
                        --cx
                                generate a new conflict file and call cplex directly''')
        sys.exit(1)

    (tcfg_map, conflict_arg, old_ilp, new_ilp,
     dir_name, flag, preempt_arg, sol_file) = sys.argv[1:9]

    preempt_limit = int(preempt_arg)
    print('preempt_limit: %d' % preempt_limit)
    assert 'sol' in sol_file
    conflict_files = [conflict_arg]
    print('sol_file: %s' % sol_file)
    print('old_ilp %s' % old_ilp)
    print('conflict_files: %r' % conflict_files)

    # --c and --cx both emit constraints; only --cx also runs cplex;
    # --i is the interactive/debug mode
    emit_conflicts = flag in ('--c', '--cx')
    do_cplex = flag == '--cx'
    interactive = flag == '--i'

    #FIXME: un-hardcode the entry point function's name
    ret = conflict('handleSyscall', tcfg_map, conflict_files, old_ilp, new_ilp, dir_name, sol_file,
                   emit_conflicts=emit_conflicts, do_cplex=do_cplex, interactive=interactive,
                   preempt_limit=preempt_limit, default_phantom_preempt=True)
    print('conflict terminated')
    print('ret: %s' % str(ret))
553
554