1# Automatically formatted with yapf (https://github.com/google/yapf)
2"""Utility functions for creating and manipulating LLVM 'opt' NPM pipeline objects."""
3
4
def fromStr(pipeStr):
    """Create pipeline object from string representation.

    A pipeline object is a list of two-element entries: a pass is stored as
    ``[None, passName]`` and a pass-manager as ``[managerName, subPipeline]``,
    where ``subPipeline`` is itself a pipeline object.  For example
    ``'a,A(b)'`` becomes ``[[None, 'a'], ['A', [[None, 'b']]]]``.
    Empty tokens (e.g. from ``'a,,b'``) are skipped.

    Raises:
        ValueError: if the parentheses in pipeStr are unbalanced, or if a
            '(' is not directly preceded by a pass-manager name.
    """
    stack = []  # saved [kind, entries] of the enclosing pass-managers
    curr = []  # entries of the pass-manager currently being parsed
    tok = ''  # characters of the name currently being read
    kind = ''  # name of the pass-manager currently being parsed
    for pos, c in enumerate(pipeStr):
        if c == ',':
            if tok != '':
                curr.append([None, tok])
            tok = ''
        elif c == '(':
            # A nameless pass-manager would produce a ['', [...]] entry that
            # the other helpers mistake for a pass, so reject it here.
            if tok == '':
                raise ValueError(
                    "'(' without pass-manager name at position {} in "
                    "pipeline {!r}".format(pos, pipeStr))
            stack.append([kind, curr])
            kind = tok
            curr = []
            tok = ''
        elif c == ')':
            if not stack:
                raise ValueError(
                    "unmatched ')' at position {} in pipeline {!r}".format(
                        pos, pipeStr))
            if tok != '':
                curr.append([None, tok])
            tok = ''
            oldKind = kind
            oldCurr = curr
            [kind, curr] = stack.pop()
            curr.append([oldKind, oldCurr])
        else:
            tok += c
    # Without this check an unclosed '(' would silently return only the
    # innermost pass-manager's entries and drop everything around it.
    if stack:
        raise ValueError("unclosed '(' in pipeline {!r}".format(pipeStr))
    if tok != '':
        curr.append([None, tok])
    return curr
34
35
def toStr(pipeObj):
    """Create string representation of pipeline object.

    Passes are emitted by name and pass-managers as ``name(contents)``, with
    contents rendered recursively.  Entries are separated by commas, and an
    empty pipeline object yields ''.
    """
    # str.join avoids repeated string concatenation and the manual
    # "is this the last element" bookkeeping for separators.
    return ','.join(
        entry[0] + '(' + toStr(entry[1]) + ')' if entry[0] else entry[1]
        for entry in pipeObj)
50
51
def count(pipeObj):
    """Return how many passes pipeObj contains, at any nesting depth.

    Pass-manager entries are not counted themselves; only the passes nested
    inside them contribute.  An empty pipeline object counts as 0.
    """
    return sum(count(entry[1]) if entry[0] else 1 for entry in pipeObj)
61
62
def split(pipeObj, splitIndex):
    """Split pipeObj into two new pipeline objects right after pass splitIndex.

    Passes are numbered from 0 in the order they appear, descending into
    pass-managers.  Passes with index <= splitIndex go to the first result
    and all later passes go to the second.  Every pass-manager is reproduced
    in both halves, so either half may contain empty pass-managers.

    Returns:
        A two-element list [before, after] of new pipeline objects.
    """
    def walk(entries, nextIdx):
        # Returns (before, after, index of the next pass to be visited).
        before = []
        after = []
        for entry in entries:
            name = entry[0]
            if name:
                subBefore, subAfter, nextIdx = walk(entry[1], nextIdx)
                before.append([name, subBefore])
                after.append([name, subAfter])
                continue
            target = before if nextIdx <= splitIndex else after
            target.append([None, entry[1]])
            nextIdx += 1
        return before, after, nextIdx

    head, tail, _ = walk(pipeObj, 0)
    return [head, tail]
85
86
def remove(pipeObj, removeIndex):
    """Return a new pipeline object equal to pipeObj minus pass removeIndex.

    Passes are numbered from 0 in the order they appear, descending into
    pass-managers.  Pass-managers are always kept, even if they become
    empty.  An index that matches no pass leaves the structure unchanged.
    """
    def strip(entries, nextIdx):
        # Returns (copied entries without the target pass, next pass index).
        kept = []
        for entry in entries:
            if entry[0]:
                inner, nextIdx = strip(entry[1], nextIdx)
                kept.append([entry[0], inner])
                continue
            if nextIdx != removeIndex:
                kept.append([None, entry[1]])
            nextIdx += 1
        return kept, nextIdx

    result, _ = strip(pipeObj, 0)
    return result
104
105
def copy(srcPipeObj):
    """Return an independent duplicate of pipeline object srcPipeObj.

    Every entry and every nested pass-manager list is newly allocated, so
    mutating the result never affects srcPipeObj.  Pass entries are written
    as [None, name].
    """
    return [[entry[0], copy(entry[1])] if entry[0] else [None, entry[1]]
            for entry in srcPipeObj]
120
121
def prune(srcPipeObj):
    """Return a copy of srcPipeObj with all empty pass-managers dropped.

    A pass-manager is empty when count() of its contents is 0, meaning it
    holds no passes at any depth.  Such managers are omitted entirely.
    Non-empty managers are kept and their contents pruned recursively.
    """
    result = []
    for entry in srcPipeObj:
        name = entry[0]
        if not name:
            result.append([None, entry[1]])
        elif count(entry[1]):
            result.append([name, prune(entry[1])])
    return result
137
138
if __name__ == "__main__":
    import unittest

    class Test(unittest.TestCase):
        """Self-tests for the pipeline helpers, run when executed as a script."""

        def test_0(self):
            pipeStr = 'a,b,A(c,B(d,e),f),g'
            pipeObj = fromStr(pipeStr)

            self.assertEqual(7, count(pipeObj))

            self.assertEqual(pipeObj, pipeObj)
            self.assertEqual(pipeObj, prune(pipeObj))
            self.assertEqual(pipeObj, copy(pipeObj))

            self.assertEqual(pipeStr, toStr(pipeObj))
            self.assertEqual(pipeStr, toStr(prune(pipeObj)))
            self.assertEqual(pipeStr, toStr(copy(pipeObj)))

            [pipeObjA, pipeObjB] = split(pipeObj, 3)
            self.assertEqual('a,b,A(c,B(d))', toStr(pipeObjA))
            self.assertEqual('A(B(e),f),g', toStr(pipeObjB))

            self.assertEqual('b,A(c,B(d,e),f),g', toStr(remove(pipeObj, 0)))
            self.assertEqual('a,b,A(c,B(d,e),f)', toStr(remove(pipeObj, 6)))

            pipeObjC = remove(pipeObj, 4)
            self.assertEqual('a,b,A(c,B(d),f),g', toStr(pipeObjC))
            pipeObjC = remove(pipeObjC, 3)
            self.assertEqual('a,b,A(c,B(),f),g', toStr(pipeObjC))
            pipeObjC = prune(pipeObjC)
            self.assertEqual('a,b,A(c,f),g', toStr(pipeObjC))

    # unittest.main() terminates the interpreter itself (exit=True by
    # default) with a non-zero status on failure, so nothing after it runs.
    unittest.main()
173