#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

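# Loop bound discovery for the binary verification problems built from the
# target functions: a naive binary search over loop iteration counts, an
# inductive binary search driven by linear-series hypotheses, and bounds
# translated from the paired C program where one is available. Results are
# cached in memory and persisted in LoopBounds.txt in the target directory.
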
import check, search, problem, syntax, solver, logic, rep_graph, re
from rep_graph import vc_num, vc_offs, pc_true_hyp, Hyp, eq_hyp
from target_objects import functions, trace
from check import restr_others, split_visit_one_visit
from problem import inline_at_point
from syntax import mk_not, true_term, false_term, mk_implies, Expr, Type, unspecified_precond_term, mk_and
from rep_graph import mk_graph_slice, VisitCount, to_smt_expr
from search import eval_model_expr
import target_objects
import trace_refute
import stack_logic
import time

# tryFun must take exactly one argument and return a boolean-like result.
def downBinSearch(minimum, maximum, tryFun):
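    """Find the least value in [minimum, maximum] for which tryFun holds,
    assuming tryFun is monotone (False below some threshold and True from
    the threshold upwards, including at maximum)."""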
    upperBound = maximum
    lowerBound = minimum
    while upperBound > lowerBound:
      print 'searching in %d - %d' % (lowerBound, upperBound)
      cur = (lowerBound + upperBound) / 2
      if tryFun(cur):
          upperBound = cur
      else:
          lowerBound = cur + 1
    assert upperBound == lowerBound
    ret = lowerBound
    return ret

def upDownBinSearch (minimum, maximum, tryFun):
    """Performs a binary search between minimum and maximum, but does not
    start in the middle. Instead it first escalates up from the minimum by
    doubling. This suits ranges such as 2 - 1000000 where the bound is
    likely to be near the bottom of the range, and it avoids testing values
    more than twice as high as the bound, which may avoid some issues."""
    upperBound = 2 * minimum
    while upperBound < maximum:
      if tryFun (upperBound):
        return downBinSearch (minimum, upperBound, tryFun)
      else:
        upperBound *= 2
    if tryFun (maximum):
      return downBinSearch (minimum, maximum, tryFun)
    else:
      return None

def addr_of_node (preds, n):
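  """Walk back through unique predecessors until an address-labelled node
  is reached."""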
  while not trace_refute.is_addr (n):
    [n] = preds[n]
  return n

def all_asm_functions ():
    ss = stack_logic.get_functions_with_tag ('ASM')
    return [s for s in ss if not stack_logic.is_instruction (s)]

call_site_set = {}

def build_call_site_set ():
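    """Populate call_site_set, mapping each function name to the set of
    instruction addresses from which it is called, across all ASM functions.
    A dummy ('IsLoaded', None) entry marks the set as built."""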
    for f in all_asm_functions ():
      preds = logic.compute_preds (functions[f].nodes)
      for (n, node) in functions[f].nodes.iteritems ():
        if node.kind == 'Call':
          s = call_site_set.setdefault (node.fname, set ())
          s.add (addr_of_node (preds, n))
    call_site_set[('IsLoaded', None)] = True

def all_call_sites (f):
    if not call_site_set:
      build_call_site_set ()
    return list (call_site_set.get (f, []))

# naive binary search to find loop bounds
def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None):
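    """Naive bound search for loop head p_n of problem p: test the candidate
    bounds in try_seq in order, then binary-search between the last failing
    and the first succeeding candidate. Returns the bound or None."""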
    if hyps is None:
      hyps = []
    #print 'restrs: %s' % str(restrs)
    if try_seq is None:
        #bound_try_seq = [1,2,3,4,5,10,50,130,200,260]
        #bound_try_seq = [0,1,2,3,4,5,10,50,260]
        calls = [n for n in p.loop_body (p_n) if p.nodes[n].kind == 'Call']
        if calls:
          bound_try_seq = [0, 1, 20]
        else:
          bound_try_seq = [0, 1, 20, 34]
    else:
        bound_try_seq = try_seq
    rep = mk_graph_slice (p, fast = True)
    # get the head
    #print 'Binary addr: %s' % toHexs(self.toPhyAddrs(p_loop_heads))
    loop_bound = None
    p_loop_heads = [n for n in p.loop_data if p.loop_data[n][0] == 'Head']
    print 'p_loop_heads: %s' % p_loop_heads

    if restrs is None:
        others = [x for x in p_loop_heads if x != p_n]
        #vc_options([concrete numbers], [offsets])
        restrs = tuple ([(n2, rep_graph.vc_options ([0], [1])) for n2 in others])

    print 'getting the initial bound'
    #try:
    index = tryLoopBound (p_n, p, bound_try_seq, rep, restrs=restrs, hyps=hyps)
    if index == -1:
        return None
    print 'got the initial bound %d' % bound_try_seq[index]

    # do a downward binary search to find the concrete loop bound
    if index == 0:
      loop_bound = bound_try_seq[0]
      print 'bound = %d' % loop_bound
      return loop_bound
    loop_bound = downBinSearch (bound_try_seq[index - 1], bound_try_seq[index],
      lambda x: tryLoopBound (p_n, p, [x], rep, restrs=restrs, hyps=hyps,
        bin_return=True))
    print 'bound = %d' % loop_bound
    return loop_bound

def default_n_vc_cases (p, n):
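  """Enumerate visit-count restriction cases for node n: if n sits in a
  loop, one case restricts the loop head to a concrete count (vc_num 1) and
  one to an offset count (vc_offs 1); other loop heads keep the default
  [0]/[1] options."""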
  head = p.loop_id (n)
  general = [(n2, rep_graph.vc_options ([0], [1]))
    for n2 in p.loop_heads ()
    if n2 != head]

  if head:
    return [(n, tuple (general + [(head, rep_graph.vc_num (1))])),
        (n, tuple (general + [(head, rep_graph.vc_offs (1))]))]
  return [(n, tuple (general))]

def callNodes(p, fs=None):
  ns = [n for n in p.nodes if p.nodes[n].kind == 'Call']
  if fs is not None:
    ns = [n for n in ns if p.nodes[n].fname in fs]
  return ns

def noHaltHyps(split, p):
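    """Hypotheses asserting that no 'halt' call site is reached, so paths
    that end in a halt call are excluded from the bound search."""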
    ret = []
    all_halts = callNodes (p, fs=['halt'])
    for x in all_halts:
      ret += [rep_graph.pc_false_hyp ((n_vc, p.node_tags[x][0]))
        for n_vc in default_n_vc_cases (p, x)]
    return ret

def tryLoopBound(p_n, p, bounds, rep, restrs=None, hints=None, kind='Number',
        bin_return=False, hyps=None):
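    """Test the candidate bounds for loop head p_n in order. Returns the
    index of the first bound that verifies, or -1 if none does; with
    bin_return=True, returns True or False instead."""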
    if restrs is None:
        restrs = ()
    if hints is None:
        hints = []
    if hyps is None:
      hyps = []
    tag = p.node_tags[p_n][0]
    from stack_logic import default_n_vc
    print 'trying bound: %s' % bounds
    ret_bounds = []
    hyps = hyps + noHaltHyps (p_n, p)
    for (index, i) in enumerate (bounds):
        print 'testing %d' % i
        restrs2 = restrs + ((p_n, VisitCount (kind, i)), )
        try:
            pc = rep.get_pc ((p_n, restrs2))
        except:
            print 'get_pc failed'
            if bin_return:
                return False
            else:
                return -1
        #print 'got rep_.get_pc'
        restrs3 = restr_others (p, restrs2, 2)
        epc = rep.get_pc (('Err', restrs3), tag = tag)
        hyp = mk_implies (mk_not (epc), mk_not (pc))

        #print 'calling test_hyp_whyps'
        if rep.test_hyp_whyps (hyp, hyps):
            print 'p_n %d: split limit found: %d' % (p_n, i)
            if bin_return:
                return True
            return index
    if bin_return:
      return False
    print 'loop bound not found!'
    return -1

def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False):
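    """Find equalities over the linear series variables of the loop at split
    that hold at every visit, by greedily testing candidates against the
    single-loop induction checks from the check module. Results are cached
    on the problem; with omit_standard, the standard sequence equalities
    from search.mk_seq_eqs are removed from the returned set."""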
    k = ('linear_series_eqs', split, restrs, tuple (hyps))
    if k in p.cached_analysis:
        if omit_standard:
          standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
          return set (p.cached_analysis[k]) - standard
        return p.cached_analysis[k]

    cands = search.mk_seq_eqs (p, split, 1, with_rodata = False)
    cands += candidate_additional_eqs (p, split)
    (tag, _) = p.node_tags[split]

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def do_checks (eqs_assume, eqs):
        checks = (check.single_loop_induct_step_checks (p, restrs, hyps, tag,
                split, 1, eqs, eqs_assume = eqs_assume)
            + check.single_loop_induct_base_checks (p, restrs, hyps, tag,
                split, 1, eqs))

        groups = check.proof_check_groups (checks)
        for group in groups:
            (res, _) = check.test_hyp_group (rep, group)
            if not res:
                return False
        return True

    eqs = []
    failed = []
    while cands:
        cand = cands.pop ()
        if do_checks (eqs, [cand]):
            eqs.append (cand)
            failed.reverse ()
            cands = failed + cands
            failed = []
        else:
            failed.append (cand)

    assert do_checks ([], eqs)
    p.cached_analysis[k] = eqs
    if omit_standard:
      standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False))
      return set (eqs) - standard
    return eqs

def get_linear_series_hyps (p, split, restrs, hyps):
    eqs = get_linear_series_eqs (p, split, restrs, hyps)
    (tag, _) = p.node_tags[split]
    hyps = [h for (h, _) in linear_eq_hyps_at_visit (tag, split, eqs,
        restrs, vc_offs (0))]
    return hyps

def is_zero (expr):
    return expr.kind == 'Num' and (expr.val & ((1 << expr.typ.num) - 1)) == 0

def candidate_additional_eqs (p, split):
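    """Collect extra candidate equalities for the loop at split: comparisons
    derived from equalities against zero found in the loop body combined
    with the loop's linear series variables, plus any assertions contributed
    by 'extra_wcet_assertions' hooks."""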
    eq_vals = set ()
    def visitor (expr):
      if expr.is_op ('Equals') and expr.vals[0].typ.kind == 'Word':
        [x, y] = expr.vals
        eq_vals.update ([(x, y), (y, x)])
    for n in p.loop_body (split):
      p.nodes[n].visit (lambda x: (), visitor)
    for (x, y) in list (eq_vals):
        if is_zero (x) and y.is_op ('Plus'):
          [x, y] = y.vals
          eq_vals.add ((x, syntax.mk_uminus (y)))
          eq_vals.add ((y, syntax.mk_uminus (x)))
        elif is_zero (x) and y.is_op ('Minus'):
          [x, y] = y.vals
          eq_vals.add ((x, y))
          eq_vals.add ((y, x))

    loop = syntax.mk_var ('%i', syntax.word32T)
    minus_loop_step = syntax.mk_uminus (loop)

    vas = search.get_loop_var_analysis_at (p, split)
    ls_vas = dict ([(var, [data]) for (var, data) in vas
        if data[0] == 'LoopLinearSeries'])
    cmp_series = [(x, y, rew, offs) for (x, y) in eq_vals
        for (_, rew, offs) in ls_vas.get (x, [])]
    odd_eqs = []
    for (x, y, rew, offs) in cmp_series:
        x_init_cmp1 = syntax.mk_less_eq (x, rew (x, minus_loop_step))
        x_init_cmp2 = syntax.mk_less_eq (rew (x, minus_loop_step), x)
        fin_cmp1 = syntax.mk_less (x, y)
        fin_cmp2 = syntax.mk_less (y, x)
        odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp1))
        odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp1))
        odd_eqs.append (syntax.mk_eq (x_init_cmp1, fin_cmp2))
        odd_eqs.append (syntax.mk_eq (x_init_cmp2, fin_cmp2))

    ass_eqs = []
    var_deps = p.compute_var_dependencies ()
    for hook in target_objects.hooks ('extra_wcet_assertions'):
        for assn in hook (var_deps[split]):
            ass_eqs.append (assn)

    return odd_eqs + ass_eqs

extra_loop_consts = [2 ** 16]

call_ctxt_problems = []

avoid_C_information = [False]

def get_call_ctxt_problem (split, call_ctxt, timing = True):
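    """Build (or fetch from a small cache) the compound problem for analysing
    address split in calling context call_ctxt, returning (problem, hyps,
    addr_map)."""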
    # time this for diagnostic reasons
    start = time.time ()
    from trace_refute import identify_function, build_compound_problem_with_links
    f = identify_function (call_ctxt, [split])
    for (ctxt2, p, hyps, addr_map) in call_ctxt_problems:
      if ctxt2 == (call_ctxt, f):
        return (p, hyps, addr_map)

    (p, hyps, addr_map) = build_compound_problem_with_links (call_ctxt, f)
    if avoid_C_information[0]:
      hyps = [h for h in hyps if not has_C_information (p, h)]
    call_ctxt_problems.append (((call_ctxt, f), p, hyps, addr_map))
    del call_ctxt_problems[: -20]

    end = time.time ()
    if timing:
      save_extra_timing ('GetProblem', call_ctxt + [split], end - start)

    return (p, hyps, addr_map)

def has_C_information (p, hyp):
    for (n_vc, tag) in hyp.visits ():
      if p.hook_tag_hints.get (tag, None) != 'ASM':
        return True
    return False

known_bound_restr_hyps = {}

known_bounds = {}

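# Bounds are persisted in LoopBounds.txt in the target directory. Each record
# is a single space-separated line, roughly one of:
#   LoopBound <addr> <bound> <kind> <n> <ctxt ..> <prob hash> <m> <prev bounds ..>
#   GlobalLoopBound <addr> <bound> <kind> <n> <ctxt ..> <functions hash>
#   LoopBoundTiming <n> <ctxt ..> <seconds>
#   ExtraTiming <name> <n> <ctxt ..> <seconds>
# written by save_bound and save_extra_timing and read back by load_bounds
# and load_timing.
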
def serialise_bound (addr, bound_info):
    if bound_info is None:
      return [hex (addr), "None", "None"]
    else:
      (bound, kind) = bound_info
      assert logic.is_int (bound)
      assert str (kind) == kind
      return [hex (addr), str (bound), kind]

def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds, bound,
        time = None):
    f_names = [trace_refute.get_body_addrs_fun (x)
      for x in call_ctxt + [split_bin_addr]]
    loop_name = '<%s>' % ' -> '.join (f_names)
    comment = '# bound for loop in %s:' % loop_name
    ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound)
    if glob:
      ss[0] = 'GlobalLoopBound'
    ss += [str (len (call_ctxt))] + map (hex, call_ctxt)
    ss += [str (prob_hash)]
    if glob:
      assert prev_bounds is None
    else:
      ss += [str (len (prev_bounds))]
      for (split, bound) in prev_bounds:
        ss += serialise_bound (split, bound)
    s = ' '.join (ss)
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (comment + '\n')
    f.write (s + '\n')
    if time is not None:
      ctxt2 = call_ctxt + [split_bin_addr]
      ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2))
      f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time))
    f.close ()
    trace ('Found bound %s for 0x%x in %s.' % (bound, split_bin_addr,
      loop_name))

def save_extra_timing (nm, ctxt, time):
    ss = ['ExtraTiming', nm, str (len (ctxt))] + map (hex, ctxt) + [str (time)]
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a')
    f.write (' '.join (ss) + '\n')
    f.close ()

def parse_bound (ss, n):
    addr = syntax.parse_int (ss[n])
    bound = ss[n + 1]
    if bound == 'None':
      bound = None
      return (n + 3, (addr, None))
    else:
      bound = syntax.parse_int (bound)
      kind = ss[n + 2]
      return (n + 3, (addr, (bound, kind)))

def parse_ctxt_id (bits, n):
    return (n + 1, syntax.parse_int (bits[n]))

def parse_ctxt (bits, n):
    return syntax.parse_list (parse_ctxt_id, bits, n)

def load_bounds ():
    try:
      f = open ('%s/LoopBounds.txt' % target_objects.target_dir)
      ls = list (f)
      f.close ()
    except IOError, e:
      ls = []
    from syntax import parse_int, parse_list
    for l in ls:
      bits = l.split ()
      if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]:
        continue
      (n, (addr, bound)) = parse_bound (bits, 1)
      (n, ctxt) = parse_ctxt (bits, n)
      prob_hash = parse_int (bits[n])
      n += 1
      if bits[0] == 'LoopBound':
        (n, prev_bounds) = parse_list (parse_bound, bits, n)
        assert n == len (bits), bits
        known = known_bounds.setdefault (addr, [])
        known.append ((ctxt, prob_hash, prev_bounds, bound))
      else:
        assert n == len (bits), bits
        known = known_bounds.setdefault ((addr, 'Global'), [])
        known.append ((ctxt, prob_hash, bound))
    known_bounds['Loaded'] = True

def get_bound_ctxt (split, call_ctxt, use_cache = True):
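    """Compute (or look up) a bound for the loop at binary address split in
    calling context call_ctxt. Bounds for loops reached before this one are
    computed first, recursively, and folded into the restrictions and
    hypotheses. Returns (bound, kind) or None and records the result via
    save_bound."""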
    trace ('Getting bound for 0x%x in context %s.' % (split, call_ctxt))
    (p, hyps, addr_map) = get_call_ctxt_problem (split, call_ctxt)

    orig_split = split
    split = p.loop_id (addr_map[split])
    assert split, (orig_split, call_ctxt)
    split_bin_addr = min ([addr for addr in addr_map
        if p.loop_id (addr_map[addr]) == split])

    prior = get_prior_loop_heads (p, split)
    restrs = ()
    prev_bounds = []
    for split2 in prior:
      # recursion!
      split2 = p.loop_id (split2)
      assert split2
      addr = min ([addr for addr in addr_map
        if p.loop_id (addr_map[addr]) == split2])
      bound = get_bound_ctxt (addr, call_ctxt)
      prev_bounds.append ((addr, bound))
      k = (p.name, split2, bound, restrs, tuple (hyps))
      if k in known_bound_restr_hyps:
        (restrs, hyps) = known_bound_restr_hyps[k]
      else:
        (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps,
          split2, bound, call_ctxt + [orig_split])
        known_bound_restr_hyps[k] = (restrs, hyps)

    # start timing now. we miss some setup time, but it avoids double counting
    # the recursive searches.
    start = time.time ()

    p_h = problem_hash (p)
    prev_bounds = sorted (prev_bounds)
    if not known_bounds:
      load_bounds ()
    known = known_bounds.get (split_bin_addr, [])
    for (call_ctxt2, h, prev_bounds2, bound) in known:
      match = (not call_ctxt2 or call_ctxt[- len (call_ctxt2):] == call_ctxt2)
      if match and use_cache and h == p_h and prev_bounds2 == prev_bounds:
        return bound
    bound = search_bin_bound (p, restrs, hyps, split)
    known = known_bounds.setdefault (split_bin_addr, [])
    known.append ((call_ctxt, p_h, prev_bounds, bound))
    end = time.time ()
    save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds, bound,
        time = end - start)
    return bound

def problem_hash (p):
    return syntax.hash_tuplify ([p.name, p.entries,
      sorted (p.outputs.iteritems ()), sorted (p.nodes.iteritems ())])

def search_bin_bound (p, restrs, hyps, split):
    trace ('Searching for bound for 0x%x in %s.' % (split, p.name))
    bound = search_bound (p, restrs, hyps, split)
    if bound:
      return bound

    # try to use a bound inferred from C
    if avoid_C_information[0]:
      # OK told not to
      return None
    if get_prior_loop_heads (p, split):
      # too difficult for now
      return None
    asm_tag = p.node_tags[split][0]
    (_, fname, _) = p.get_entry_details (asm_tag)
    funs = [f for pair in target_objects.pairings[fname]
      for f in pair.funs.values ()]
    c_tags = [tag for tag in p.tags ()
        if p.get_entry_details (tag)[1] in funs and tag != asm_tag]
    if len (c_tags) != 1:
      print 'Surprised to see %d matching tags (%s)' % (len (c_tags), c_tags)
      return None

    [c_tag] = c_tags

    rep = rep_graph.mk_graph_slice (p)
    if len (search.get_loop_entry_sites (rep, restrs, hyps, split)) != 1:
      # technical, but it's not going to work in this case
      return None

    return getBinaryBoundFromC (p, c_tag, split, restrs, hyps)

def rab_test ():
    [(split_bin_addr, _, _)] = get_loop_heads (functions['resolveAddressBits'])
    (p, hyps, addr_map) = get_call_ctxt_problem (split_bin_addr, [])
    split = p.loop_id (addr_map[split_bin_addr])
    [c_tag] = [tag for tag in p.tags () if tag != p.node_tags[split][0]]
    return getBinaryBoundFromC (p, c_tag, split, (), hyps)

last_search_bound = [0]

def search_bound (p, restrs, hyps, split):
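    """Search for a bound on the loop at split: first a naive binary search
    over small candidate bounds, then an inductive binary search using the
    linear series hypotheses, then a second, deeper naive search. Returns
    (bound, kind) or None."""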
    last_search_bound[0] = (p, restrs, hyps, split)

    # try a naive bin search first
    # limit this to a small bound for time purposes
    #   - for larger bounds the less naive approach can be faster
    bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps,
      try_seq = [0, 1, 6])
    if bound is not None:
      return (bound, 'NaiveBinSearch')

    l_hyps = get_linear_series_hyps (p, split, restrs, hyps)

    rep = rep_graph.mk_graph_slice (p, fast = True)

    def test (n):
        assert n > 10
        hyp = check.mk_loop_counter_eq_hyp (p, split, restrs, n - 2)
        visit = ((split, vc_offs (2)), ) + restrs
        continue_to_split_guess = rep.get_pc ((split, visit))
        return rep.test_hyp_whyps (syntax.mk_not (continue_to_split_guess),
          [hyp] + l_hyps + hyps)

    # findLoopBoundBS always checks to at least 16
    min_bound = 16
    max_bound = max_acceptable_bound[0]
    bound = upDownBinSearch (min_bound, max_bound, test)
    if bound is not None and test (bound):
      return (bound, 'InductiveBinSearch')

    # let the naive bin search go a bit further
    bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps)
    if bound is not None:
      return (bound, 'NaiveBinSearch')

    return None

def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps):
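    """Derive a bound for the ASM loop at asm_split from a bound already
    found on the corresponding C loop: find and check a split relation
    between the two loops, then translate the C iteration bound through the
    discovered linear sequences, plus one in case the splitting point occurs
    later in the loop."""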
    c_heads = [h for h in search.init_loops_to_split (p, restrs)
      if p.node_tags[h][0] == c_tag]
    c_bounds = [(p.loop_id (split), search_bound (p, (), hyps, split))
      for split in c_heads]
    if not [b for (n, b) in c_bounds if b]:
      trace ('no C bounds found (%s).' % c_bounds)
      return None

    asm_tag = p.node_tags[asm_split][0]

    rep = rep_graph.mk_graph_slice (p)
    i_seq_opts = [(0, 1), (1, 1), (2, 1)]
    j_seq_opts = [(0, 1), (0, 2), (1, 1)]
    tags = [p.node_tags[asm_split][0], c_tag]
    try:
      split = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts,
        j_seq_opts, 5, tags = [asm_tag, c_tag])
    except solver.SolverFailure, e:
      return None
    if not split or split[0] != 'Split':
      trace ('no split found (%s).' % repr (split))
      return None
    (_, split) = split
    rep = rep_graph.mk_graph_slice (p)
    checks = check.split_checks (p, (), hyps, split, tags = [asm_tag, c_tag])
    groups = check.proof_check_groups (checks)
    try:
      for group in groups:
        (res, el) = check.test_hyp_group (rep, group)
        if not res:
          trace ('split check failed!')
          trace ('failed at %s' % el)
          return None
    except solver.SolverFailure, e:
      return None
    (as_details, c_details, _, n, _) = split
    (c_split, (seq_start, step), _) = c_details
    c_bound = dict (c_bounds).get (p.loop_id (c_split))
    if not c_bound:
      trace ('key split was not bounded (%r, %r).' % (c_split, c_bounds))
      return None
    (c_bound, _) = c_bound
    max_it = (c_bound - seq_start) / step
    assert max_it > n, (max_it, n)
    (_, (seq_start, step), _) = as_details
    as_bound = seq_start + (max_it * step)
    # increment by 1 as this may be a bound for a different splitting point
    # which occurs later in the loop
    as_bound += 1
    return (as_bound, 'FromC')

def get_prior_loop_heads (p, split, use_rep = None):
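    """Loop heads of p from which the loop at split is reachable, ordered so
    that each head appears after the heads prior to it."""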
    if use_rep:
      rep = use_rep
    else:
      rep = rep_graph.mk_graph_slice (p)
    prior = []
    split = p.loop_id (split)
    for h in p.loop_heads ():
      s = set (prior)
      if h not in s and rep.get_reachable (h, split) and h != split:
        # need to recurse to ensure prior are in order
        prior2 = get_prior_loop_heads (p, h, use_rep = rep)
        prior.extend ([h2 for h2 in prior2 if h2 not in s])
        prior.append (h)
    return prior

def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt):
    # time this for diagnostic reasons
    start = time.time ()

    #vc_options([concrete numbers], [offsets])
    hyps = hyps + get_linear_series_hyps (p, split, restrs, hyps)
    hyps = list (set (hyps))
    if bound is None or bound >= 10:
        restrs = restrs + ((split, rep_graph.vc_options ([0], [1])),)
    else:
        restrs = restrs + ((split, rep_graph.vc_upto (bound + 1)),)

    end = time.time ()
    save_extra_timing ('LoopBoundRestrHyps', ctxt, end - start)

    return (restrs, hyps)

max_acceptable_bound = [1000000]

functions_hash = [None]

def get_functions_hash ():
    if functions_hash[0] is not None:
      return functions_hash[0]
    h = hash (tuple (sorted ([(f, hash (functions[f])) for f in functions])))
    functions_hash[0] = h
    return h

addr_to_loop_id_cache = {}
complex_loop_id_cache = {}

def addr_to_loop_id (split):
    if split not in addr_to_loop_id_cache:
      add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return addr_to_loop_id_cache[split]

def is_complex_loop (split):
    split = addr_to_loop_id (split)
    if split not in complex_loop_id_cache:
      add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split))
    return complex_loop_id_cache[split]

def get_loop_addrs (split):
    split = addr_to_loop_id (split)
    f = functions[trace_refute.get_body_addrs_fun (split)]
    return [addr for addr in f.nodes if trace_refute.is_addr (addr)
        if addr_to_loop_id_cache.get (addr) == split]

def add_fun_to_loop_data_cache (fname):
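    """Run loop analysis on fname and record, for every instruction address
    in each loop body, the minimal such address as the canonical loop id,
    together with whether that loop has an inner loop."""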
    p = functions[fname].as_problem (problem.Problem)
    p.do_loop_analysis ()
    for h in p.loop_heads ():
      addrs = [n for n in p.loop_body (h)
        if trace_refute.is_addr (n)]
      min_addr = min (addrs)
      for addr in addrs:
        addr_to_loop_id_cache[addr] = min_addr
      complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h)
    return min_addr

def get_bound_super_ctxt (split, call_ctxt, no_splitting=False,
        known_bound_only=False):
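    """Top-level bound query for a loop address: normalise the address to
    its canonical loop id, consult the global cache, and otherwise fall back
    to get_bound_super_ctxt_inner, recording the result unless splitting was
    suppressed part-way. With known_bound_only, only cached results are
    returned."""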
    if not known_bounds:
      load_bounds ()
    for (ctxt2, fn_hash, bound) in known_bounds.get ((split, 'Global'), []):
      if ctxt2 == call_ctxt and fn_hash == get_functions_hash ():
        return bound
    min_loop_addr = addr_to_loop_id (split)
    if min_loop_addr != split:
      return get_bound_super_ctxt (min_loop_addr, call_ctxt,
            no_splitting = no_splitting, known_bound_only = known_bound_only)

    if known_bound_only:
        return None
    no_splitting_abort = [False]
    try:
      bound = get_bound_super_ctxt_inner (split, call_ctxt,
        no_splitting = (no_splitting, no_splitting_abort))
    except problem.Abort, e:
      bound = None
    if no_splitting_abort[0]:
      # don't record this bound, since it might change if splitting was allowed
      return bound
    known = known_bounds.setdefault ((split, 'Global'), [])
    known.append ((call_ctxt, get_functions_hash (), bound))
    save_bound (True, split, call_ctxt, get_functions_hash (), None, bound)
    return bound

from trace_refute import (function_limit, ctxt_within_function_limits)

def call_ctxt_computable (split, call_ctxt):
    fs = [trace_refute.identify_function ([], [call_site])
      for call_site in call_ctxt]
    non_computable = [f for f in fs if trace_refute.has_complex_loop (f)]
    if non_computable:
      trace ('avoiding functions with complex loops: %s' % non_computable)
    return not non_computable

def get_bound_super_ctxt_inner (split, call_ctxt,
      no_splitting = (False, None)):
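    """Bound search with context splitting: apply function limits where they
    settle the question, extend the calling context when there is a unique
    safe call site, try a direct bound in this context, and otherwise split
    over the possible call sites and take the worst of the resulting
    bounds."""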
    first_f = trace_refute.identify_function ([], (call_ctxt + [split])[:1])
    call_sites = all_call_sites (first_f)

    if function_limit (first_f) == 0:
      return (0, 'FunctionLimit')
    safe_call_sites = [cs for cs in call_sites
      if ctxt_within_function_limits ([cs] + call_ctxt)]
    if call_sites and not safe_call_sites:
      return (0, 'FunctionLimit')

    if len (call_ctxt) < 3 and len (safe_call_sites) == 1:
      call_ctxt2 = list (safe_call_sites) + call_ctxt
      if call_ctxt_computable (split, call_ctxt2):
        trace ('using unique calling context %s' % str ((split, call_ctxt2)))
        return get_bound_super_ctxt (split, call_ctxt2)

    fname = trace_refute.identify_function (call_ctxt, [split])
    bound = function_limit_bound (fname, split)
    if bound:
      return bound

    bound = get_bound_ctxt (split, call_ctxt)
    if bound:
      return bound

    trace ('no bound found immediately.')

    if no_splitting[0]:
      assert no_splitting[1], no_splitting
      no_splitting[1][0] = True
      trace ('cannot split by context (recursion).')
      return None

    # try to split over potential call sites
    if len (call_ctxt) >= 3:
      trace ('cannot split by context (context depth).')
      return None

    if len (call_sites) == 0:
      # either entry point or nonsense
      trace ('cannot split by context (reached top level).')
      return None

    problem_sites = [call_site for call_site in safe_call_sites
        if not call_ctxt_computable (split, [call_site] + call_ctxt)]
    if problem_sites:
      trace ('cannot split by context (issues in %s).' % problem_sites)
      return None

    anc_bounds = [get_bound_super_ctxt (split, [call_site] + call_ctxt,
        no_splitting = True)
      for call_site in safe_call_sites]
    if None in anc_bounds:
      return None
    (bound, kind) = max (anc_bounds)
    return (bound, 'MergedBound')

def function_limit_bound (fname, split):
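    """If every path around the loop at split must pass through a call to a
    function with a trace_refute function limit, bound the loop by the sum
    of those limits plus one."""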
    p = functions[fname].as_problem (problem.Problem)
    p.do_analysis ()
    cuts = [n for n in p.loop_body (split)
        if p.nodes[n].kind == 'Call'
        if function_limit (p.nodes[n].fname) is not None]
    if not cuts:
      return None
    graph = p.mk_node_graph (p.loop_body (split))
    # it is not possible to iterate the loop without visiting a bounded
    # function. naively, this sets the limit to the sum of all the possible
    # bounds, plus one because we can enter the loop a final time without
    # visiting any function call site yet.
    if logic.divides_loop (graph, set (cuts)):
      fnames = set ([p.nodes[n].fname for n in cuts])
      return (sum ([function_limit (f) for f in fnames]) + 1, 'FunctionLimit')

def loop_bound_difficulty_estimates (split, ctxt):
    # various guesses at how hard the loop bounding problem is.
    (p, hyps, addr_map) = get_call_ctxt_problem (split, ctxt, timing = False)

    loop_id = p.loop_id (addr_map[split])
    assert loop_id

    # number of instructions in the loop
    inst_node_ids = set (addr_map.itervalues ())
    l_insts = [n for n in p.loop_body (loop_id) if n in inst_node_ids]

    # number of instructions in the function
    tag = p.node_tags[loop_id][0]
    f_insts = [n for n in inst_node_ids if p.node_tags[n][0] == tag]

    # number of instructions in the whole calling context
    ctxt_insts = len (inst_node_ids)

    # well, what else?
    return (len (l_insts), len (f_insts), ctxt_insts)

def load_timing ():
    f = open ('%s/LoopBounds.txt' % target_objects.target_dir)
    timing = {}
    loop_time = 0.0
    ext_time = 0.0
    for line in f:
      bits = line.split ()
      if not (bits and 'Timing' in bits[0]):
        continue
      if bits[0] == 'LoopBoundTiming':
        (n, ext_ctxt) = parse_ctxt (bits, 1)
        assert n == len (bits) - 1
        time = float (bits[n])
        ctxt = ext_ctxt[:-1]
        split = ext_ctxt[-1]
        timing[(split, tuple (ctxt))] = time
        loop_time += time
      elif bits[0] == 'ExtraTiming':
        time = float (bits[-1])
        ext_time += time
    f.close ()
    f = open ('%s/time' % target_objects.target_dir)
    [l] = [l for l in f if '(wall clock)' in l]
    f.close ()
    tot_time_str = l.split ()[-1]
    tot_time = sum ([float (s) * (60 ** i)
      for (i, s) in enumerate (reversed (tot_time_str.split (':')))])

    return (loop_time, ext_time, tot_time, timing)

def mk_timing_metrics ():
    if not known_bounds:
      load_bounds ()
    probs = [(split_bin_addr, tuple (call_ctxt), bound)
      for (split_bin_addr, known) in known_bounds.iteritems ()
      if type (split_bin_addr) == int
      for (call_ctxt, h, prev_bounds, bound) in known]
    probs = set (probs)
    data = [(split, ctxt, bound,
        loop_bound_difficulty_estimates (split, list (ctxt)))
      for (split, ctxt, bound) in probs]
    return data

# sigh, this is so much work.
bound_kind_nums = {
  'FunctionLimit': 2,
  'NaiveBinSearch': 3,
  'InductiveBinSearch': 4,
  'FromC': 5,
  'MergedBound': 6,
}

gnuplot_colours = [
  "dark-red", "dark-blue", "dark-green", "dark-grey",
  "dark-orange", "dark-magenta", "dark-cyan"]

def save_timing_metrics (num):
    (loop_time, ext_time, tot_time, timing) = load_timing ()

    col = gnuplot_colours[num]
    from target import short_name

    time_ests = mk_timing_metrics ()
    import os
    f = open ('%s/LoopTimingMetrics.txt' % target_objects.target_dir, 'w')
    f.write ('"%s"\n' % short_name)

    for (split, ctxt, bound, ests) in time_ests:
      time = timing[(split, tuple (ctxt))]
      if bound is None:
        bdata = "1000000 7"
      else:
        bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]])
      (l_i, f_i, ct_i) = ests
      f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i,
        bdata, col, time))
    f.close ()

def get_loop_heads (fun):
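    """Return (addr, function name, has inner loop) for each loop in fun,
    using the smallest instruction address in the loop body as the loop's
    representative address."""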
    if not fun.entry:
      return []
    p = fun.as_problem (problem.Problem)
    p.do_loop_analysis ()
    loops = set ()
    for h in p.loop_heads ():
      # any address in the loop will do. pick the smallest one
      addr = min ([n for n in p.loop_body (h) if trace_refute.is_addr (n)])
      loops.add ((addr, fun.name, problem.has_inner_loop (p, h)))
    return list (loops)

def get_all_loop_heads ():
    loops = set ()
    abort_funs = set ()
    for f in all_asm_functions ():
      try:
        loops.update (get_loop_heads (functions[f]))
      except problem.Abort, e:
        abort_funs.add (f)
    if abort_funs:
      trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs))
    return loops

def get_complex_loops ():
    return [(loop, name) for (loop, name, compl) in get_all_loop_heads ()
        if compl]

def search_all_loops ():
    all_loops = get_all_loop_heads ()
    for (loop, _, _) in all_loops:
      get_bound_super_ctxt (loop, [])

main = search_all_loops

if __name__ == '__main__':
    import sys
    args = target_objects.load_target_args ()
    if args == ['search']:
      search_all_loops ()
    elif args[:1] == ['metrics']:
      num = args[1:].index (str (target_objects.target_dir))
      save_timing_metrics (num)