import check, search, problem, syntax, solver, logic, rep_graph, re
from rep_graph import vc_num, vc_offs, pc_true_hyp, Hyp, eq_hyp
from target_objects import functions, trace
from check import restr_others, split_visit_one_visit
from problem import inline_at_point
from syntax import mk_not, true_term, false_term, mk_implies, Expr, Type, unspecified_precond_term, mk_and
from rep_graph import mk_graph_slice, VisitCount, to_smt_expr
from search import eval_model_expr
import target_objects
import trace_refute
import stack_logic
import time

#tryFun must take exactly 1 argument
def downBinSearch(minimum, maximum, tryFun):
    upperBound = maximum
    lowerBound = minimum
    while upperBound > lowerBound:
      print 'searching in %d - %d' % (lowerBound,upperBound)
      cur = (lowerBound + upperBound) / 2
      if tryFun(cur):
          upperBound = cur
      else:
          lowerBound = cur + 1
    assert upperBound == lowerBound
    ret = lowerBound
    return ret
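
# Illustrative sketch (not used by the analysis): downBinSearch assumes tryFun
# is monotone, i.e. once it succeeds for some value it succeeds for every
# larger value, and that it succeeds at maximum; it then returns the least
# value in [minimum, maximum] for which tryFun holds.
def _downBinSearch_example ():
    # the least n in [1, 100] with n * n >= 300 is 18
    return downBinSearch (1, 100, lambda n: n * n >= 300)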

def upDownBinSearch (minimum, maximum, tryFun):
    """Binary search between minimum and maximum which does not start in the
    middle: it first escalates binarily up from the minimum. This suits
    ranges such as 2 - 1000000 where the bound is likely to be near the
    bottom of the range, and it avoids testing values more than twice as
    high as the bound, which may avoid some issues."""
    upperBound = 2 * minimum
    while upperBound < maximum:
      if tryFun (upperBound):
        return downBinSearch (minimum, upperBound, tryFun)
      else:
        upperBound *= 2
    if tryFun (maximum):
      return downBinSearch (minimum, maximum, tryFun)
    else:
      return None
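
# Illustrative sketch (not used by the analysis): with the predicate below,
# upDownBinSearch probes 32, 64 and 128, then narrows down within [16, 128]
# and returns 100.
def _upDownBinSearch_example ():
    return upDownBinSearch (16, 1000000, lambda n: n >= 100)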

def addr_of_node (preds, n):
  while not trace_refute.is_addr (n):
    [n] = preds[n]
  return n

def all_asm_functions ():
    ss = stack_logic.get_functions_with_tag ('ASM')
    return [s for s in ss if not stack_logic.is_instruction (s)]

call_site_set = {}

def build_call_site_set ():
    for f in all_asm_functions ():
      preds = logic.compute_preds (functions[f].nodes)
      for (n, node) in functions[f].nodes.iteritems ():
        if node.kind == 'Call':
          s = call_site_set.setdefault (node.fname, set ())
          s.add (addr_of_node (preds, n))
    call_site_set[('IsLoaded', None)] = True

def all_call_sites (f):
    if not call_site_set:
      build_call_site_set ()
    return list (call_site_set.get (f, []))

#naive binary search to find loop bounds
def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None):
    if hyps == None:
      hyps = []
    #print 'restrs: %s' % str(restrs)
    if try_seq == None:
        #bound_try_seq = [1,2,3,4,5,10,50,130,200,260]
        #bound_try_seq = [0,1,2,3,4,5,10,50,260]
        calls = [n for n in p.loop_body (p_n) if p.nodes[n].kind == 'Call']
        if calls:
          bound_try_seq = [0,1,20]
        else:
          bound_try_seq = [0,1,20,34]
    else:
        bound_try_seq = try_seq
    rep = mk_graph_slice (p, fast = True)
    #get the head
    #print 'Binary addr: %s' % toHexs(self.toPhyAddrs(p_loop_heads))
    loop_bound = None
    p_loop_heads = [n for n in p.loop_data if p.loop_data[n][0] == 'Head']
    print 'p_loop_heads: %s' % p_loop_heads

    if restrs == None:
        others = [x for x in p_loop_heads if not x == p_n]
        #vc_options([concrete numbers], [offsets])
        restrs = tuple( [(n2, rep_graph.vc_options([0],[1])) for n2 in others] )

    print 'getting the initial bound'
    #try:
    index = tryLoopBound(p_n,p,bound_try_seq,rep,restrs = restrs, hyps=hyps)
    if index == -1:
        return None
    print 'got the initial bound %d' % bound_try_seq[index]

    #do a downward binary search to find the concrete loop bound
    if index == 0:
      loop_bound = bound_try_seq[0]
      print 'bound = %d' % loop_bound
      return loop_bound
    loop_bound = downBinSearch(bound_try_seq[index-1], bound_try_seq[index], lambda x: tryLoopBound(p_n,p,[x],rep,restrs=restrs, hyps=hyps, bin_return=True))
    print 'bound = %d' % loop_bound
    return loop_bound

def default_n_vc_cases (p, n):
  head = p.loop_id (n)
  general = [(n2, rep_graph.vc_options ([0], [1]))
  for n2 in p.loop_heads ()
  if n2 != head]

  if head:
    return [(n, tuple (general + [(head, rep_graph.vc_num (1))])),
         (n, tuple (general + [(head, rep_graph.vc_offs (1))]))]
  specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head]
  return [(n, tuple (general + specific))]

def callNodes(p, fs = None):
  ns = [n for n in p.nodes if p.nodes[n].kind == 'Call']
  if fs != None:
    ns = [n for n in ns if p.nodes[n].fname in fs]
  return ns

def noHaltHyps(split,p):
    ret = []
    all_halts = callNodes(p,fs=['halt'])
    for x in all_halts:
       ret += [rep_graph.pc_false_hyp((n_vc, p.node_tags[x][0]))
         for n_vc in default_n_vc_cases (p, x)]
    return ret

def tryLoopBound(p_n, p, bounds, rep, restrs = None, hints = None,
        kind = 'Number', bin_return = False, hyps = None):
    if restrs == None:
        restrs = ()
    if hints == None:
        hints = []
    if hyps == None:
      hyps = []
    tag = p.node_tags[p_n][0]
    from stack_logic import default_n_vc
    print 'trying bound: %s' % bounds
    ret_bounds = []
    for (index,i) in enumerate(bounds):
            print 'testing %d' % i
            restrs2 = restrs + ((p_n, VisitCount (kind, i)), )
            try:
                pc = rep.get_pc ((p_n, restrs2))
            except:
                print 'get_pc failed'
                if bin_return:
                    return False
                else:
                    return -1
            #print 'got rep_.get_pc'
            restrs3 = restr_others (p, restrs2, 2)
            epc = rep.get_pc (('Err', restrs3), tag = tag)
            hyp = mk_implies (mk_not (epc), mk_not (pc))
            hyps = hyps + noHaltHyps(p_n,p)

            #hyps = []
            #print 'calling test_hyp_whyps'
            if rep.test_hyp_whyps (hyp, hyps):
                    print 'p_n %d: split limit found: %d' % (p_n, i)
                    if bin_return:
                      return True
                    return index
    if bin_return:
      return False
    print 'loop bound not found!'
    return -1

def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False):
    k = ('linear_series_eqs', split, restrs, tuple (hyps))
    if k in p.cached_analysis:
        if omit_standard:
          standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
          return set (p.cached_analysis[k]) - standard
        return p.cached_analysis[k]

    cands = search.mk_seq_eqs (p, split, 1, with_rodata = False)
    cands += candidate_additional_eqs (p, split)
    (tag, _) = p.node_tags[split]

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def do_checks (eqs_assume, eqs):
        checks = (check.single_loop_induct_step_checks (p, restrs, hyps, tag,
                split, 1, eqs, eqs_assume = eqs_assume)
            + check.single_loop_induct_base_checks (p, restrs, hyps, tag,
                split, 1, eqs))

        groups = check.proof_check_groups (checks)
        for group in groups:
            (res, _) = check.test_hyp_group (rep, group)
            if not res:
                return False
        return True

    eqs = []
    failed = []
    while cands:
        cand = cands.pop ()
        if do_checks (eqs, [cand]):
            eqs.append (cand)
            failed.reverse ()
            cands = failed + cands
            failed = []
        else:
            failed.append (cand)

    assert do_checks ([], eqs)
    p.cached_analysis[k] = eqs
    if omit_standard:
      standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
      return set (eqs) - standard
    return eqs

def get_linear_series_hyps (p, split, restrs, hyps):
    eqs = get_linear_series_eqs (p, split, restrs, hyps)
    (tag, _) = p.node_tags[split]
    hyps = [h for (h, _) in linear_eq_hyps_at_visit (tag, split, eqs,
             restrs, vc_offs (0))]
    return hyps

def is_zero (expr):
    return expr.kind == 'Num' and expr.val & ((1 << expr.typ.num) - 1) == 0

def candidate_additional_eqs (p, split):
    eq_vals = set ()
    def visitor (expr):
      if expr.is_op ('Equals') and expr.vals[0].typ.kind == 'Word':
        [x, y] = expr.vals
        eq_vals.update ([(x, y), (y, x)])
    for n in p.loop_body (split):
      p.nodes[n].visit (lambda x: (), visitor)
    for (x, y) in list (eq_vals):
        if is_zero (x) and y.is_op ('Plus'):
          [x, y] = y.vals
          eq_vals.add ((x, syntax.mk_uminus (y)))
          eq_vals.add ((y, syntax.mk_uminus (x)))
        elif is_zero (x) and y.is_op ('Minus'):
          [x, y] = y.vals
          eq_vals.add ((x, y))
          eq_vals.add ((y, x))

    loop = syntax.mk_var ('%i', syntax.word32T)
    minus_loop_step = syntax.mk_uminus (loop)

    vas = search.get_loop_var_analysis_at(p, split)
    ls_vas = dict ([(var, [data]) for (var, data) in vas
        if data[0] == 'LoopLinearSeries'])
    cmp_series = [(x, y, rew, offs) for (x, y) in eq_vals
        for (_, rew, offs) in ls_vas.get (x, [])]
    odd_eqs = []
    for (x, y, rew, offs) in cmp_series:
        x_init_cmp1 = syntax.mk_less_eq (x, rew (x, minus_loop_step))
        x_init_cmp2 = syntax.mk_less_eq (rew (x, minus_loop_step), x)
        fin_cmp1 = syntax.mk_less (x, y)
        fin_cmp2 = syntax.mk_less (y, x)
        odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp1))
        odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp1))
        odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp2))
        odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp2))

    ass_eqs = []
    var_deps = p.compute_var_dependencies ()
    for hook in target_objects.hooks ('extra_wcet_assertions'):
        for assn in hook (var_deps[split]):
            ass_eqs.append (assn)

    return odd_eqs + ass_eqs

extra_loop_consts = [2 ** 16]

call_ctxt_problems = []

avoid_C_information = [False]

def get_call_ctxt_problem (split, call_ctxt, timing = True):
    # time this for diagnostic reasons
    start = time.time ()
    from trace_refute import identify_function, build_compound_problem_with_links
    f = identify_function (call_ctxt, [split])
    for (ctxt2, p, hyps, addr_map) in call_ctxt_problems:
      if ctxt2 == (call_ctxt, f):
        return (p, hyps, addr_map)

    (p, hyps, addr_map) = build_compound_problem_with_links (call_ctxt, f)
    if avoid_C_information[0]:
      hyps = [h for h in hyps if not has_C_information (p, h)]
    call_ctxt_problems.append(((call_ctxt, f), p, hyps, addr_map))
    del call_ctxt_problems[: -20]

    end = time.time ()
    if timing:
      save_extra_timing ('GetProblem', call_ctxt + [split], end - start)

    return (p, hyps, addr_map)

def has_C_information (p, hyp):
    for (n_vc, tag) in hyp.visits ():
      if not p.hook_tag_hints.get (tag, None) == 'ASM':
        return True
    return False

known_bound_restr_hyps = {}

known_bounds = {}

def serialise_bound (addr, bound_info):
    if bound_info == None:
      return [hex(addr), "None", "None"]
    else:
      (bound, kind) = bound_info
      assert logic.is_int (bound)
      assert str (kind) == kind
      return [hex (addr), str (bound), kind]

def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds, bound,
        time = None):
    f_names = [trace_refute.get_body_addrs_fun (x)
      for x in call_ctxt + [split_bin_addr]]
    loop_name = '<%s>' % ' -> '.join (f_names)
    comment = '# bound for loop in %s:' % loop_name
    ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound)
    if glob:
      ss[0] = 'GlobalLoopBound'
    ss += [str (len (call_ctxt))] + map (hex, call_ctxt)
    ss += [str (prob_hash)]
    if glob:
      assert prev_bounds == None
    else:
      ss += [str (len (prev_bounds))]
      for (split, bound) in prev_bounds:
        ss += serialise_bound (split, bound)
    s = ' '.join (ss)
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (comment + '\n')
    f.write (s + '\n')
    if time != None:
      ctxt2 = call_ctxt + [split_bin_addr]
      ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2))
      f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time))
    f.close ()
    trace ('Found bound %s for 0x%x in %s.' % (bound, split_bin_addr,
      loop_name))

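# Illustrative sketch of a 'LoopBound' record as written above (addresses and
# hash made up): the line
#   LoopBound 0xf0012345 16 NaiveBinSearch 1 0xf000abcd 987654321 0
# records the split address with its bound and kind, then the calling context
# (length, then addresses), the problem hash, and finally the previously
# established bounds (length, then address/bound/kind triples), matching
# parse_bound and parse_ctxt below.
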
def save_extra_timing (nm, ctxt, time):
    ss = ['ExtraTiming', nm, str (len (ctxt))] + map (hex, ctxt) + [str(time)]
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (' '.join (ss) + '\n')
    f.close ()

def parse_bound (ss, n):
    addr = syntax.parse_int (ss[n])
    bound = ss[n + 1]
    if bound == 'None':
      bound = None
      return (n + 3, (addr, None))
    else:
      bound = syntax.parse_int (bound)
      kind = ss[n + 2]
      return (n + 3, (addr, (bound, kind)))

def parse_ctxt_id (bits, n):
    return (n + 1, syntax.parse_int (bits[n]))

def parse_ctxt (bits, n):
    return syntax.parse_list (parse_ctxt_id, bits, n)

def load_bounds ():
    try:
      f = open ('%s/LoopBounds.txt' % target_objects.target_dir)
      ls = list (f)
      f.close ()
    except IOError, e:
      ls = []
    from syntax import parse_int, parse_list
    for l in ls:
      bits = l.split ()
      if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]:
        continue
      (n, (addr, bound)) = parse_bound (bits, 1)
      (n, ctxt) = parse_ctxt (bits, n)
      prob_hash = parse_int (bits[n])
      n += 1
      if bits[0] == 'LoopBound':
        (n, prev_bounds) = parse_list (parse_bound, bits, n)
        assert n == len (bits), bits
        known = known_bounds.setdefault (addr, [])
        known.append ((ctxt, prob_hash, prev_bounds, bound))
      else:
        assert n == len (bits), bits
        known = known_bounds.setdefault ((addr, 'Global'), [])
        known.append ((ctxt, prob_hash, bound))
    known_bounds['Loaded'] = True

def get_bound_ctxt (split, call_ctxt, use_cache = True):
    trace ('Getting bound for 0x%x in context %s.' % (split, call_ctxt))
    (p, hyps, addr_map) = get_call_ctxt_problem (split, call_ctxt)

    orig_split = split
    split = p.loop_id (addr_map[split])
    assert split, (orig_split, call_ctxt)
    split_bin_addr = min ([addr for addr in addr_map
        if p.loop_id (addr_map[addr]) == split])

    prior = get_prior_loop_heads (p, split)
    restrs = ()
    prev_bounds = []
    for split2 in prior:
      # recursion!
      split2 = p.loop_id (split2)
      assert split2
      addr = min ([addr for addr in addr_map
        if p.loop_id (addr_map[addr]) == split2])
      bound = get_bound_ctxt (addr, call_ctxt)
      prev_bounds.append ((addr, bound))
      k = (p.name, split2, bound, restrs, tuple (hyps))
      if k in known_bound_restr_hyps:
        (restrs, hyps) = known_bound_restr_hyps[k]
      else:
        (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps,
          split2, bound, call_ctxt + [orig_split])
        known_bound_restr_hyps[k] = (restrs, hyps)

    # start timing now. we miss some setup time, but it avoids double counting
    # the recursive searches.
    start = time.time ()

    p_h = problem_hash (p)
    prev_bounds = sorted (prev_bounds)
    if not known_bounds:
      load_bounds ()
    known = known_bounds.get (split_bin_addr, [])
    for (call_ctxt2, h, prev_bounds2, bound) in known:
      match = (not call_ctxt2 or call_ctxt[- len (call_ctxt2):] == call_ctxt2)
      if match and use_cache and h == p_h and prev_bounds2 == prev_bounds:
        return bound
    bound = search_bin_bound (p, restrs, hyps, split)
    known = known_bounds.setdefault (split_bin_addr, [])
    known.append ((call_ctxt, p_h, prev_bounds, bound))
    end = time.time ()
    save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds, bound,
        time = end - start)
    return bound

def problem_hash (p):
    return syntax.hash_tuplify ([p.name, p.entries,
      sorted (p.outputs.iteritems ()), sorted (p.nodes.iteritems ())])

def search_bin_bound (p, restrs, hyps, split):
    trace ('Searching for bound for 0x%x in %s.' % (split, p.name))
    bound = search_bound (p, restrs, hyps, split)
    if bound:
      return bound

    # try to use a bound inferred from C
    if avoid_C_information[0]:
      # OK told not to
      return None
    if get_prior_loop_heads (p, split):
      # too difficult for now
      return None
    asm_tag = p.node_tags[split][0]
    (_, fname, _) = p.get_entry_details (asm_tag)
    funs = [f for pair in target_objects.pairings[fname]
      for f in pair.funs.values ()]
    c_tags = [tag for tag in p.tags ()
        if p.get_entry_details (tag)[1] in funs and tag != asm_tag]
    if len (c_tags) != 1:
      print 'Surprised to see multiple matching tags %s' % c_tags
      return None

    [c_tag] = c_tags

    rep = rep_graph.mk_graph_slice (p)
    if len (search.get_loop_entry_sites (rep, restrs, hyps, split)) != 1:
      # technical, but it's not going to work in this case
      return None

    return getBinaryBoundFromC (p, c_tag, split, restrs, hyps)

def rab_test ():
    [split_bin_addr] = get_loop_heads (functions['resolveAddressBits'])
    (p, hyps, addr_map) = get_call_ctxt_problem (split_bin_addr, [])
    split = p.loop_id (addr_map[split_bin_addr])
    [c_tag] = [tag for tag in p.tags () if tag != p.node_tags[split][0]]
    return getBinaryBoundFromC (p, c_tag, split, (), hyps)

last_search_bound = [0]

def search_bound (p, restrs, hyps, split):
    last_search_bound[0] = (p, restrs, hyps, split)

    # try a naive bin search first
    # limit this to a small bound for time purposes
    #   - for larger bounds the less naive approach can be faster
    bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps,
      try_seq = [0, 1, 6])
    if bound != None:
      return (bound, 'NaiveBinSearch')

    l_hyps = get_linear_series_hyps (p, split, restrs, hyps)

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def test (n):
        assert n > 10
        hyp = check.mk_loop_counter_eq_hyp (p, split, restrs, n - 2)
        visit = ((split, vc_offs (2)), ) + restrs
        continue_to_split_guess = rep.get_pc ((split, visit))
        return rep.test_hyp_whyps (syntax.mk_not (continue_to_split_guess),
          [hyp] + l_hyps + hyps)

    # findLoopBoundBS always checks to at least 16
    min_bound = 16
    max_bound = max_acceptable_bound[0]
    bound = upDownBinSearch (min_bound, max_bound, test)
    if bound != None and test (bound):
      return (bound, 'InductiveBinSearch')

    # let the naive bin search go a bit further
    bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps)
    if bound != None:
      return (bound, 'NaiveBinSearch')

    return None

def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps):
    c_heads = [h for h in search.init_loops_to_split (p, restrs)
      if p.node_tags[h][0] == c_tag]
    c_bounds = [(p.loop_id (split), search_bound (p, (), hyps, split))
      for split in c_heads]
    if not [b for (n, b) in c_bounds if b]:
      trace ('no C bounds found (%s).' % c_bounds)
      return None

    asm_tag = p.node_tags[asm_split][0]

    rep = rep_graph.mk_graph_slice (p)
    i_seq_opts = [(0, 1), (1, 1), (2, 1)]
    j_seq_opts = [(0, 1), (0, 2), (1, 1)]
    tags = [p.node_tags[asm_split][0], c_tag]
    try:
      split = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts,
        j_seq_opts, 5, tags = [asm_tag, c_tag])
    except solver.SolverFailure, e:
      return None
    if not split or split[0] != 'Split':
      trace ('no split found (%s).' % repr (split))
      return None
    (_, split) = split
    rep = rep_graph.mk_graph_slice (p)
    checks = check.split_checks (p, (), hyps, split, tags = [asm_tag, c_tag])
    groups = check.proof_check_groups (checks)
    try:
      for group in groups:
        (res, el) = check.test_hyp_group (rep, group)
        if not res:
          trace ('split check failed!')
          trace ('failed at %s' % el)
          return None
    except solver.SolverFailure, e:
      return None
    (as_details, c_details, _, n, _) = split
    (c_split, (seq_start, step), _) = c_details
    c_bound = dict (c_bounds).get (p.loop_id (c_split))
    if not c_bound:
      trace ('key split was not bounded (%r, %r).' % (c_split, c_bounds))
      return None
    (c_bound, _) = c_bound
    max_it = (c_bound - seq_start) / step
    assert max_it > n, (max_it, n)
    (_, (seq_start, step), _) = as_details
    as_bound = seq_start + (max_it * step)
    # increment by 1 as this may be a bound for a different splitting point
    # which occurs later in the loop
    as_bound += 1
    return (as_bound, 'FromC')
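
# Illustrative sketch with made-up numbers (not used by the analysis): if the
# C loop is bounded by c_bound = 64 and the discovered split relates a C
# sequence starting at 0 with step 1 to an ASM sequence starting at 2 with
# step 4, the derived ASM bound is 2 + 64 * 4 + 1 = 259, the final + 1 being
# the adjustment explained above.
def _fromC_bound_example ():
    (c_bound, c_seq_start, c_step) = (64, 0, 1)
    (as_seq_start, as_step) = (2, 4)
    max_it = (c_bound - c_seq_start) / c_step
    return as_seq_start + (max_it * as_step) + 1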

def get_prior_loop_heads (p, split, use_rep = None):
    if use_rep:
      rep = use_rep
    else:
      rep = rep_graph.mk_graph_slice (p)
    prior = []
    split = p.loop_id (split)
    for h in p.loop_heads ():
      s = set (prior)
      if h not in s and rep.get_reachable (h, split) and h != split:
        # need to recurse to ensure prior are in order
        prior2 = get_prior_loop_heads (p, h, use_rep = rep)
        prior.extend ([h2 for h2 in prior2 if h2 not in s])
        prior.append (h)
    return prior

def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt):
    # time this for diagnostic reasons
    start = time.time ()

    #vc_options([concrete numbers], [offsets])
    hyps = hyps + get_linear_series_hyps (p, split, restrs, hyps)
    hyps = list (set (hyps))
    if bound == None or bound >= 10:
        restrs = restrs + ((split, rep_graph.vc_options([0],[1])),)
    else:
        restrs = restrs + ((split, rep_graph.vc_upto (bound+1)),)

    end = time.time ()
    save_extra_timing ('LoopBoundRestrHyps', ctxt, end - start)

    return (restrs, hyps)

max_acceptable_bound = [1000000]

functions_hash = [None]

def get_functions_hash ():
    if functions_hash[0] != None:
      return functions_hash[0]
    h = hash (tuple (sorted ([(f, hash (functions[f])) for f in functions])))
    functions_hash[0] = h
    return h

addr_to_loop_id_cache = {}
complex_loop_id_cache = {}

def addr_to_loop_id (split):
    if split not in addr_to_loop_id_cache:
      add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return addr_to_loop_id_cache[split]

def is_complex_loop (split):
    split = addr_to_loop_id (split)
    if split not in complex_loop_id_cache:
      add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return complex_loop_id_cache[split]

def get_loop_addrs (split):
    split = addr_to_loop_id (split)
    f = functions[trace_refute.get_body_addrs_fun (split)]
    return [addr for addr in f.nodes if trace_refute.is_addr (addr)
        if addr_to_loop_id_cache.get (addr) == split]

def add_fun_to_loop_data_cache (fname):
    p = functions[fname].as_problem (problem.Problem)
    p.do_loop_analysis ()
    min_addr = None
    for h in p.loop_heads ():
      addrs = [n for n in p.loop_body (h)
        if trace_refute.is_addr (n)]
      min_addr = min (addrs)
      for addr in addrs:
        addr_to_loop_id_cache[addr] = min_addr
      complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h)
    return min_addr

def get_bound_super_ctxt (split, call_ctxt, no_splitting=False,
        known_bound_only=False):
    if not known_bounds:
      load_bounds ()
    for (ctxt2, fn_hash, bound) in known_bounds.get ((split, 'Global'), []):
      if ctxt2 == call_ctxt and fn_hash == get_functions_hash ():
        return bound
    min_loop_addr = addr_to_loop_id (split)
    if min_loop_addr != split:
      return get_bound_super_ctxt (min_loop_addr, call_ctxt,
            no_splitting = no_splitting, known_bound_only = known_bound_only)

    if known_bound_only:
        return None
    no_splitting_abort = [False]
    try:
      bound = get_bound_super_ctxt_inner (split, call_ctxt,
        no_splitting = (no_splitting, no_splitting_abort))
    except problem.Abort, e:
      bound = None
    if no_splitting_abort[0]:
      # don't record this bound, since it might change if splitting was allowed
      return bound
    known = known_bounds.setdefault ((split, 'Global'), [])
    known.append ((call_ctxt, get_functions_hash (), bound))
    save_bound (True, split, call_ctxt, get_functions_hash (), None, bound)
    return bound

from trace_refute import (function_limit, ctxt_within_function_limits)

def call_ctxt_computable (split, call_ctxt):
    fs = [trace_refute.identify_function ([], [call_site])
      for call_site in call_ctxt]
    non_computable = [f for f in fs if trace_refute.has_complex_loop (f)]
    if non_computable:
      trace ('avoiding functions with complex loops: %s' % non_computable)
    return not non_computable

def get_bound_super_ctxt_inner (split, call_ctxt,
      no_splitting = (False, None)):
    first_f = trace_refute.identify_function ([], (call_ctxt + [split])[:1])
    call_sites = all_call_sites (first_f)

    if function_limit (first_f) == 0:
      return (0, 'FunctionLimit')
    safe_call_sites = [cs for cs in call_sites
      if ctxt_within_function_limits ([cs] + call_ctxt)]
    if call_sites and not safe_call_sites:
      return (0, 'FunctionLimit')

    if len (call_ctxt) < 3 and len (safe_call_sites) == 1:
      call_ctxt2 = list (safe_call_sites) + call_ctxt
      if call_ctxt_computable (split, call_ctxt2):
        trace ('using unique calling context %s' % str ((split, call_ctxt2)))
        return get_bound_super_ctxt (split, call_ctxt2)

    fname = trace_refute.identify_function (call_ctxt, [split])
    bound = function_limit_bound (fname, split)
    if bound:
      return bound

    bound = get_bound_ctxt (split, call_ctxt)
    if bound:
      return bound

    trace ('no bound found immediately.')

    if no_splitting[0]:
      assert no_splitting[1], no_splitting
      no_splitting[1][0] = True
      trace ('cannot split by context (recursion).')
      return None

    # try to split over potential call sites
    if len (call_ctxt) >= 3:
      trace ('cannot split by context (context depth).')
      return None

    if len (call_sites) == 0:
      # either entry point or nonsense
      trace ('cannot split by context (reached top level).')
      return None

    problem_sites = [call_site for call_site in safe_call_sites
        if not call_ctxt_computable (split, [call_site] + call_ctxt)]
    if problem_sites:
      trace ('cannot split by context (issues in %s).' % problem_sites)
      return None

    anc_bounds = [get_bound_super_ctxt (split, [call_site] + call_ctxt,
        no_splitting = True)
      for call_site in safe_call_sites]
    if None in anc_bounds:
      return None
    (bound, kind) = max (anc_bounds)
    return (bound, 'MergedBound')

def function_limit_bound (fname, split):
    p = functions[fname].as_problem (problem.Problem)
    p.do_analysis ()
    cuts = [n for n in p.loop_body (split)
        if p.nodes[n].kind == 'Call'
        if function_limit (p.nodes[n].fname) != None]
    if not cuts:
      return None
    graph = p.mk_node_graph (p.loop_body (split))
    # it is not possible to iterate the loop without visiting a bounded
    # function. naively, this sets the limit to the sum of all the possible
    # bounds, plus one because we can enter the loop a final time without
    # visiting any function call site yet.
    if logic.divides_loop (graph, set (cuts)):
      fnames = set ([p.nodes[n].fname for n in cuts])
      return (sum ([function_limit (f) for f in fnames]) + 1, 'FunctionLimit')
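
# Illustrative sketch with made-up limits: if every cycle through the loop body
# passes through a call to one of two limited functions, with function_limit
# values 3 and 5, the loop can be entered at most 3 + 5 + 1 = 9 times.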

def loop_bound_difficulty_estimates (split, ctxt):
    # various guesses at how hard the loop bounding problem is.
    (p, hyps, addr_map) = get_call_ctxt_problem (split, ctxt, timing = False)

    loop_id = p.loop_id (addr_map[split])
    assert loop_id

    # number of instructions in the loop
    inst_node_ids = set (addr_map.itervalues ())
    l_insts = [n for n in p.loop_body (loop_id) if n in inst_node_ids]

    # number of instructions in the function
    tag = p.node_tags[loop_id][0]
    f_insts = [n for n in inst_node_ids if p.node_tags[n][0] == tag]

    # number of instructions in the whole calling context
    ctxt_insts = len (inst_node_ids)

    # well, what else?
    return (len (l_insts), len (f_insts), ctxt_insts)

def load_timing ():
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir)
    timing = {}
    loop_time = 0.0
    ext_time = 0.0
    for line in f:
      bits = line.split ()
      if not (bits and 'Timing' in bits[0]):
        continue
      if bits[0] == 'LoopBoundTiming':
        (n, ext_ctxt) = parse_ctxt (bits, 1)
        assert n == len (bits) - 1
        time = float (bits[n])
        ctxt = ext_ctxt[:-1]
        split = ext_ctxt[-1]
        timing[(split, tuple(ctxt))] = time
        loop_time += time
      elif bits[0] == 'ExtraTiming':
        time = float (bits[-1])
        ext_time += time
    f.close ()
    f = open ('%s/time' % target_objects.target_dir)
    [l] = [l for l in f if '(wall clock)' in l]
    f.close ()
    tot_time_str = l.split ()[-1]
    tot_time = sum ([float(s) * (60 ** i)
      for (i, s) in enumerate (reversed (tot_time_str.split(':')))])

    return (loop_time, ext_time, tot_time, timing)
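
# Illustrative sketch: the '(wall clock)' field parsed above has the form
# [h:]m:ss, so e.g. '1:02:03' yields 3 + 2 * 60 + 1 * 3600 = 3723.0 seconds.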

def mk_timing_metrics ():
    if not known_bounds:
      load_bounds ()
    probs = [(split_bin_addr, tuple (call_ctxt), bound)
      for (split_bin_addr, known) in known_bounds.iteritems ()
      if type (split_bin_addr) == int
      for (call_ctxt, h, prev_bounds, bound) in known]
    probs = set (probs)
    data = [(split, ctxt, bound,
        loop_bound_difficulty_estimates (split, list (ctxt)))
      for (split, ctxt, bound) in probs]
    return data

# sigh, this is so much work.
bound_kind_nums = {
  'FunctionLimit': 2,
  'NaiveBinSearch': 3,
  'InductiveBinSearch': 4,
  'FromC': 5,
  'MergedBound': 6,
}

gnuplot_colours = [
  "dark-red", "dark-blue", "dark-green", "dark-grey",
  "dark-orange", "dark-magenta", "dark-cyan"]

def save_timing_metrics (num):
    (loop_time, ext_time, tot_time, timing) = load_timing ()

    col = gnuplot_colours[num]
    from target import short_name

    time_ests = mk_timing_metrics ()
    import os
    f = open ('%s/LoopTimingMetrics.txt' % target_objects.target_dir, 'w')
    f.write ('"%s"\n' % short_name)

    for (split, ctxt, bound, ests) in time_ests:
      time = timing[(split, tuple (ctxt))]
      if bound == None:
        bdata = "1000000 7"
      else:
        bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]])
      (l_i, f_i, ct_i) = ests
      f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i,
        bdata, col, time))
    f.close ()

def get_loop_heads (fun):
    if not fun.entry:
      return []
    p = fun.as_problem (problem.Problem)
    p.do_loop_analysis ()
    loops = set ()
    for h in p.loop_heads ():
      # any address in the loop will do. pick the smallest one
      addr = min ([n for n in p.loop_body (h) if trace_refute.is_addr (n)])
      loops.add ((addr, fun.name, problem.has_inner_loop (p, h)))
    return list (loops)

def get_all_loop_heads ():
    loops = set ()
    abort_funs = set ()
    for f in all_asm_functions ():
      try:
        loops.update (get_loop_heads (functions[f]))
      except problem.Abort, e:
        abort_funs.add (f)
    if abort_funs:
      trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs))
    return loops

def get_complex_loops ():
    return [(loop, name) for (loop, name, compl) in get_all_loop_heads ()
        if compl]

def search_all_loops ():
    all_loops = get_all_loop_heads ()
    for (loop, _, _) in all_loops:
      get_bound_super_ctxt (loop, [])

main = search_all_loops

if __name__ == '__main__':
    import sys
    args = target_objects.load_target_args ()
    if args == ['search']:
      search_all_loops ()
    elif args[:1] == ['metrics']:
      num = args[1:].index (str (target_objects.target_dir))
      save_timing_metrics (num)