1# SPDX-License-Identifier: GPL-2.0
2
3import builtins
4import inspect
5import sys
6import time
7import traceback
8from .consts import KSFT_MAIN_NAME
9
# Outcome of the test case currently being run by ksft_run(): reset to True
# before each case and cleared to False by _fail() when a ksft_* check fails.
# None until the first case starts.
KSFT_RESULT = None
# Aggregate of every result reported through ktap_result(); ksft_exit() turns
# it into the process exit status (0 if all ok, 1 otherwise).
KSFT_RESULT_ALL = True
12
13
class KsftFailEx(Exception):
    """Raise from a test case to abort it; ksft_run() reports the case as failed."""
16
17
class KsftSkipEx(Exception):
    """Raise from a test case to report it as "ok ... # SKIP <message>"."""
20
21
class KsftXfailEx(Exception):
    """Raise from a test case to report it as "ok ... # XFAIL <message>"."""
24
25
def ksft_pr(*objs, **kwargs):
    """Print *objs* as a KTAP diagnostic line, i.e. prefixed with "#".

    Keyword arguments (``sep``, ``end``, ``file``, ...) go straight to print().
    """
    prefixed = ("#", *objs)
    print(*prefixed, **kwargs)
28
29
def _fail(*args):
    """Record a failed check for the running test case and report where it was made.

    Clears KSFT_RESULT so ksft_run() reports the current case as "not ok",
    then prints (as "#" diagnostic lines) the file and line of the caller of
    the ksft_* helper that detected the failure, followed by *args*.
    Does not raise; the test case keeps running.
    """
    global KSFT_RESULT
    KSFT_RESULT = False

    # Frame 0 is _fail, frame 1 the ksft_* helper, frame 2 the code that
    # called the helper. context=0: only filename/lineno are needed, so do
    # not read source lines for every frame on the stack.
    stack = inspect.stack(0)
    # Clamp in case _fail is reached through a shallower stack than expected
    # (e.g. called directly from module level) instead of raising IndexError.
    caller = stack[min(2, len(stack) - 1)]
    location = "At " + caller.filename + " line " + str(caller.lineno) + ":"
    # The list also holds this very frame; drop the references now rather
    # than leaving a frame <-> local reference cycle for the GC.
    del stack, caller

    ksft_pr(location)
    ksft_pr(*args)
37
38
def ksft_eq(a, b, comment=""):
    """Check that ``a == b``.

    On mismatch the current test case is marked failed (via _fail) and both
    values plus *comment* are printed; execution continues either way.
    """
    # No `global` needed: this function only delegates, it never assigns
    # KSFT_RESULT itself (matching the other ksft_* checkers).
    if a != b:
        _fail("Check failed", a, "!=", b, comment)
43
44
def ksft_true(a, comment=""):
    """Check that *a* is truthy; otherwise mark the current case failed.

    Execution continues after a failed check.
    """
    if a:
        return
    _fail("Check failed", a, "does not eval to True", comment)
48
49
def ksft_in(a, b, comment=""):
    """Check that *a* is a member of *b* (``a in b``).

    A failed check marks the current case failed; execution continues.
    """
    if a in b:
        return
    _fail("Check failed", a, "not in", b, comment)
53
54
def ksft_ge(a, b, comment=""):
    """Check that *a* is not less than *b*.

    Only ``a < b`` counts as a failure, so unordered values (e.g. NaN)
    pass. A failed check marks the current case failed; execution continues.
    """
    if not (a < b):
        return
    _fail("Check failed", a, "<", b, comment)
58
59
class ksft_raises:
    """Context manager checking that its body raises *expected_type*.

    Usage::

        with ksft_raises(ValueError) as cm:
            ...
        # cm.exception is the raised exception (None if nothing was raised)

    An exception that is an instance of *expected_type*, subclasses included
    (the same rule an ``except`` clause uses), is stored in ``exception``
    and suppressed. If nothing is raised, the check is marked failed.
    If another exception type is raised, the check is marked failed and
    the exception propagates.
    """

    def __init__(self, expected_type):
        self.exception = None
        self.expected_type = expected_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised")
            matched = False
        else:
            matched = issubclass(exc_type, self.expected_type)
            if not matched:
                _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}")
        self.exception = exc_val
        # Suppress the exception only if it's the expected one
        return matched
76
77
def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
    """Poll *cond* until it returns truthy or *deadline* seconds elapse.

    *cond* is re-evaluated after each *sleep*-second pause, so a condition
    that becomes true during the last pause still succeeds. On timeout the
    current case is marked failed (with *comment*) and the function returns
    normally; it does not raise.
    """
    give_up_at = time.monotonic() + deadline
    while not cond():
        if time.monotonic() > give_up_at:
            _fail("Waiting for condition timed out", comment)
            return
        time.sleep(sleep)
87
88
def ktap_result(ok, cnt=1, case="", comment=""):
    """Print one KTAP result line and fold *ok* into KSFT_RESULT_ALL.

    Line format::

        [not ]ok <cnt> <KSFT_MAIN_NAME>[.<case.__name__>][ # <comment>]

    *case*, when truthy, must have a ``__name__`` (a test function).
    """
    global KSFT_RESULT_ALL
    KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok

    parts = ["ok " if ok else "not ok ", str(cnt), " ", KSFT_MAIN_NAME]
    if case:
        parts.append("." + str(case.__name__))
    if comment:
        parts.append(" # " + comment)
    print("".join(parts))
104
105
def _ksft_collect_cases(cases, globs, case_pfx):
    """Return a new list: *cases*, then callables in *globs* whose name starts with a *case_pfx* prefix."""
    # Copy so the caller's list is never appended to.
    collected = list(cases) if cases else []
    if globs and case_pfx:
        # A bare string would otherwise be iterated one character at a time,
        # turning "test_" into the prefixes "t", "e", "s", ...
        if isinstance(case_pfx, str):
            prefixes = (case_pfx,)
        else:
            prefixes = tuple(case_pfx)
        for key, value in globs.items():
            if callable(value) and key.startswith(prefixes):
                collected.append(value)
    return collected


def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
    """Run test cases and print their results in KTAP format.

    cases    -- explicit test callables; they run first, in order.
    globs    -- optional namespace (e.g. ``globals()``) scanned for more
                cases; only used together with *case_pfx*.
    case_pfx -- prefix string, or iterable of prefix strings; callables in
                *globs* whose name starts with any of them are added.
    args     -- positional arguments passed to every case.

    Each case is reported as:
      skip  -- it raised KsftSkipEx;
      xfail -- it raised KsftXfailEx;
      fail  -- it raised any other Exception (traceback printed), or a
               ksft_* check cleared KSFT_RESULT while it ran;
      pass  -- otherwise.
    A "# Totals:" summary line is printed at the end.
    """
    global KSFT_RESULT

    cases = _ksft_collect_cases(cases, globs, case_pfx)
    totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}

    print("KTAP version 1")
    print("1.." + str(len(cases)))

    for cnt, case in enumerate(cases, start=1):
        KSFT_RESULT = True
        try:
            case(*args)
        except KsftSkipEx as e:
            ktap_result(True, cnt, case, comment="SKIP " + str(e))
            totals['skip'] += 1
            continue
        except KsftXfailEx as e:
            ktap_result(True, cnt, case, comment="XFAIL " + str(e))
            totals['xfail'] += 1
            continue
        except Exception:
            tb = traceback.format_exc()
            for line in tb.strip().split('\n'):
                ksft_pr("Exception|", line)
            ktap_result(False, cnt, case)
            totals['fail'] += 1
            continue

        # The case returned normally; its verdict is whatever the ksft_*
        # checks left in KSFT_RESULT.
        ktap_result(KSFT_RESULT, cnt, case)
        totals['pass' if KSFT_RESULT else 'fail'] += 1

    print(
        f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
    )
155
156
def ksft_exit():
    """Exit the process: status 0 if every reported result was ok, else 1."""
    # Only read here, so no `global` statement is required.
    sys.exit(int(not KSFT_RESULT_ALL))
160