1# SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause)
2
import errno
import logging
import os
import struct
import sys
import unittest
from argparse import ArgumentParser
from argparse import FileType

import tpm2
from tpm2 import ProtocolError
12
class SmokeTest(unittest.TestCase):
    """End-to-end sanity checks run directly against the TPM device.

    Every test gets a fresh ``tpm2.Client`` plus a root key created in
    ``setUp()``; both are released again in ``tearDown()``.
    """

    def setUp(self):
        self.client = tpm2.Client()
        try:
            self.root_key = self.client.create_root_key()
        except BaseException:
            # tearDown() is skipped when setUp() raises, so close the client
            # here instead of leaking it.
            self.client.close()
            raise

    def tearDown(self):
        try:
            self.client.flush_context(self.root_key)
        finally:
            # Always close the client, even if flushing the root key failed.
            self.client.close()

    @staticmethod
    def _get_random_cmd(num_bytes=0x20):
        """Return a raw, session-less TPM2_GetRandom command.

        The header's size field is set to the packed length of the whole
        command; ``num_bytes`` is the number of random bytes requested.
        """
        fmt = '>HIIH'
        return struct.pack(fmt,
                           tpm2.TPM2_ST_NO_SESSIONS,
                           struct.calcsize(fmt),
                           tpm2.TPM2_CC_GET_RANDOM,
                           num_bytes)

    def _create_policy_digest(self, pcrs, bank_alg):
        """Compute the PolicyPCR + PolicyPassword digest in a trial session.

        The trial session is always flushed before returning or raising.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_TRIAL)
        try:
            self.client.policy_pcr(handle, pcrs, bank_alg=bank_alg)
            self.client.policy_password(handle)
            return self.client.get_policy_digest(handle)
        finally:
            self.client.flush_context(handle)

    def _unseal_with_policy(self, blob, auth, pcrs, bank_alg):
        """Unseal ``blob`` through a PolicyPCR + PolicyPassword session.

        Returns the unsealed data. If any step raises, the policy session
        is flushed and the exception is re-raised unchanged.
        """
        handle = self.client.start_auth_session(tpm2.TPM2_SE_POLICY)
        try:
            self.client.policy_pcr(handle, pcrs, bank_alg=bank_alg)
            self.client.policy_password(handle)
            return self.client.unseal(self.root_key, blob, auth, handle)
        except BaseException:
            # NOTE(review): the session is only flushed on failure; presumably
            # a successful unseal consumes it (no continueSession) — confirm
            # in tpm2.Client.unseal().
            self.client.flush_context(handle)
            raise

    def test_seal_with_auth(self):
        data = ('X' * 64).encode()
        auth = ('A' * 15).encode()

        blob = self.client.seal(self.root_key, data, auth, None)
        result = self.client.unseal(self.root_key, blob, auth, None)
        self.assertEqual(data, result)

    def determine_bank_alg(self, mask):
        """Return the first PCR bank algorithm whose selection covers ``mask``.

        ``mask`` is a bitmask of PCR indices (bit n selects PCR n). Returns
        None when no bank's selection contains every bit of ``mask``.
        """
        pcr_banks = self.client.get_cap_pcrs()
        for bank_alg, pcrSelection in pcr_banks.items():
            if pcrSelection & mask == mask:
                return bank_alg
        return None

    def test_seal_with_policy(self):
        bank_alg = self.determine_bank_alg(1 << 16)
        self.assertIsNotNone(bank_alg)

        data = ('X' * 64).encode()
        auth = ('A' * 15).encode()
        pcrs = [16]

        policy_dig = self._create_policy_digest(pcrs, bank_alg)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        result = self._unseal_with_policy(blob, auth, pcrs, bank_alg)
        self.assertEqual(data, result)

    def test_unseal_with_wrong_auth(self):
        data = ('X' * 64).encode()
        auth = ('A' * 20).encode()

        blob = self.client.seal(self.root_key, data, auth, None)
        # Corrupt the last byte of the password.
        with self.assertRaises(ProtocolError) as cm:
            self.client.unseal(self.root_key, blob,
                               auth[:-1] + 'B'.encode(), None)

        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_AUTH_FAIL)

    def test_unseal_with_wrong_policy(self):
        # The bank must contain both PCR 16 (in the policy) and PCR 1 (not in it).
        bank_alg = self.determine_bank_alg(1 << 16 | 1 << 1)
        self.assertIsNotNone(bank_alg)

        data = ('X' * 64).encode()
        auth = ('A' * 17).encode()
        pcrs = [16]

        policy_dig = self._create_policy_digest(pcrs, bank_alg)
        blob = self.client.seal(self.root_key, data, auth, policy_dig)

        # Extend first a PCR that is not part of the policy and try to unseal.
        # This should succeed.
        ds = tpm2.get_digest_size(bank_alg)
        self.client.extend_pcr(1, ('X' * ds).encode(), bank_alg=bank_alg)

        result = self._unseal_with_policy(blob, auth, pcrs, bank_alg)
        self.assertEqual(data, result)

        # Then, extend a PCR that is part of the policy and try to unseal.
        # This should fail.
        self.client.extend_pcr(16, ('X' * ds).encode(), bank_alg=bank_alg)

        with self.assertRaises(ProtocolError) as cm:
            self._unseal_with_policy(blob, auth, pcrs, bank_alg)

        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_POLICY_FAIL)

    def test_seal_with_too_long_auth(self):
        # One byte longer than a SHA-1 digest.
        ds = tpm2.get_digest_size(tpm2.TPM2_ALG_SHA1)
        data = ('X' * 64).encode()
        auth = ('A' * (ds + 1)).encode()

        with self.assertRaises(ProtocolError) as cm:
            self.client.seal(self.root_key, data, auth, None)

        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_SIZE)

    def test_too_short_cmd(self):
        # The header claims one byte more than is actually written; the
        # write is expected to be rejected with an I/O error.
        fmt = '>HIII'
        cmd = struct.pack(fmt,
                          tpm2.TPM2_ST_NO_SESSIONS,
                          struct.calcsize(fmt) + 1,
                          tpm2.TPM2_CC_FLUSH_CONTEXT,
                          0xDEADBEEF)

        with self.assertRaises(IOError):
            self.client.send_cmd(cmd)

    def test_read_partial_resp(self):
        """Read a response in two parts: the 10-byte header, then the rest."""
        cmd = self._get_random_cmd()
        self.client.tpm.write(cmd)
        hdr = self.client.tpm.read(10)
        sz = struct.unpack('>I', hdr[2:6])[0]
        rsp = self.client.tpm.read()

        # Header (10) + size field (2) + 0x20 random bytes.
        self.assertEqual(sz, 10 + 2 + 32)
        self.assertEqual(len(rsp), 2 + 32)

    def test_read_partial_overwrite(self):
        """A new command after a partial read yields a complete new response."""
        cmd = self._get_random_cmd()
        self.client.tpm.write(cmd)
        # Read part of the response
        rsp1 = self.client.tpm.read(15)

        # Send a new cmd
        self.client.tpm.write(cmd)

        # Read the whole response
        rsp2 = self.client.tpm.read()

        self.assertEqual(len(rsp1), 15)
        self.assertEqual(len(rsp2), 10 + 2 + 32)

    def test_send_two_cmds(self):
        cmd = self._get_random_cmd()
        self.client.tpm.write(cmd)

        rejected = False
        try:
            # expect the second one to raise -EBUSY error
            self.client.tpm.write(cmd)
        except IOError:
            rejected = True

        # Read the pending response so the device is left idle.
        self.client.tpm.read()
        self.assertTrue(rejected)
237
class SpaceTest(unittest.TestCase):
    """Tests for resource-manager spaces (``tpm2.Client.FLAG_SPACE``)."""

    def setUp(self):
        logging.basicConfig(filename='SpaceTest.log', level=logging.DEBUG)

    def _open_space(self):
        """Open a space client that is closed automatically after the test."""
        space = tpm2.Client(tpm2.Client.FLAG_SPACE)
        # Close the device handle even when the test fails part-way.
        self.addCleanup(space.close)
        return space

    def test_make_two_spaces(self):
        log = logging.getLogger(__name__)
        log.debug("test_make_two_spaces")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        space2 = self._open_space()
        root2 = space2.create_root_key()
        root3 = space2.create_root_key()

        log.debug("%08x", root1)
        log.debug("%08x", root2)
        log.debug("%08x", root3)

    def test_flush_context(self):
        log = logging.getLogger(__name__)
        log.debug("test_flush_context")

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        space1.flush_context(root1)

    def test_get_handles(self):
        log = logging.getLogger(__name__)
        log.debug("test_get_handles")

        space1 = self._open_space()
        space1.create_root_key()
        space2 = self._open_space()
        space2.create_root_key()
        space2.create_root_key()

        # Only the two objects created in space2 should be visible from it.
        handles = space2.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_TRANSIENT)

        self.assertEqual(len(handles), 2)

        log.debug("%08x", handles[0])
        log.debug("%08x", handles[1])

    def test_invalid_cc(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        TPM2_CC_INVALID = tpm2.TPM2_CC_FIRST - 1

        space1 = self._open_space()
        root1 = space1.create_root_key()
        log.debug("%08x", root1)

        fmt = '>HII'
        cmd = struct.pack(fmt, tpm2.TPM2_ST_NO_SESSIONS, struct.calcsize(fmt),
                          TPM2_CC_INVALID)

        with self.assertRaises(ProtocolError) as cm:
            space1.send_cmd(cmd)

        # The error is expected to carry the resource-manager layer bits.
        self.assertEqual(cm.exception.rc, tpm2.TPM2_RC_COMMAND_CODE |
                         tpm2.TSS2_RESMGR_TPM_RC_LAYER)
305
class AsyncTest(unittest.TestCase):
    """Tests for clients opened with ``tpm2.Client.FLAG_NONBLOCK``."""

    def setUp(self):
        logging.basicConfig(filename='AsyncTest.log', level=logging.DEBUG)

    def test_async(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        async_client = tpm2.Client(tpm2.Client.FLAG_NONBLOCK)
        # Close the client even if get_cap() raises.
        self.addCleanup(async_client.close)
        log.debug("Calling get_cap in a NON_BLOCKING mode")
        async_client.get_cap(tpm2.TPM2_CAP_HANDLES, tpm2.HR_LOADED_SESSION)

    def test_flush_invalid_context(self):
        log = logging.getLogger(__name__)
        log.debug(sys._getframe().f_code.co_name)

        async_client = tpm2.Client(tpm2.Client.FLAG_SPACE | tpm2.Client.FLAG_NONBLOCK)
        # Close the client even when the assertion below fails.
        self.addCleanup(async_client.close)
        log.debug("Calling flush_context passing in an invalid handle ")
        handle = 0x80123456

        with self.assertRaises(OSError) as cm:
            async_client.flush_context(handle)

        self.assertEqual(cm.exception.errno, errno.EINVAL)
334