1/*-
2 * Copyright (c) 2018, Juniper Networks, Inc.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 *    notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 *    notice, this list of conditions and the following disclaimer in the
11 *    documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
14 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
15 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
16 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
17 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
18 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
19 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include <sys/cdefs.h>
27__FBSDID("$FreeBSD$");
28
29#include <libsecureboot.h>
30
31#include "decode.h"
32
/**
 * @brief convert octets to an upper-case hex string
 *
 * @param[in] ptr
 *	octets to convert
 *
 * @param[in] n
 *	number of octets at ptr
 *
 * @return malloc'd NUL-terminated string of 2 * n hex digits which
 * the caller must free, or NULL if ptr is NULL while n > 0,
 * the required size would wrap, or malloc fails.
 */
char *
octets2hex(unsigned char *ptr, size_t n)
{
	static const char digits[] = "0123456789ABCDEF";
	char *hex;
	char *cp;
	size_t i;

	if (ptr == NULL && n > 0)
		return (NULL);
	if (n > ((size_t)-1 - 1) / 2)
		return (NULL);		/* 2 * n + 1 would wrap */

	hex = malloc(2 * n + 1);
	if (hex == NULL)
		return (NULL);

	cp = hex;
	for (i = 0; i < n; i++) {
		*cp++ = digits[ptr[i] >> 4];
		*cp++ = digits[ptr[i] & 0x0f];
	}
	/* terminate explicitly so that n == 0 yields "" */
	*cp = '\0';
	return (hex);
}
48
/**
 * @brief encode an integer as n big-endian octets
 *
 * Octets beyond the width of size_t are emitted as zero.
 *
 * @param[in] n
 *	number of octets to produce (at most 15)
 *
 * @param[in] i
 *	value to encode
 *
 * @return pointer to a static buffer holding the octets, most
 * significant first, or NULL if n > 15.
 * The buffer is overwritten by the next call.
 * If n <= 0 the buffer is returned unmodified.
 */
unsigned char *
i2octets(int n, size_t i)
{
	static unsigned char o[16];
	int x, j;

	if (n > 15)
		return (NULL);
	for (j = 0, x = n - 1; x >= 0; x--, j++) {
		/*
		 * Shift the value down rather than a mask up:
		 * the mask is an int and shifting it 24 or more bits
		 * is undefined.  Never shift by the full width either.
		 */
		if ((size_t)x >= sizeof(i))
			o[j] = 0;
		else
			o[j] = (unsigned char)((i >> (x * 8)) & 0xff);
	}
	return (o);
}
62
/**
 * @brief decode n big-endian octets to an int
 *
 * The value is accumulated as unsigned so that no shift can
 * overflow a signed int.  If n exceeds sizeof(int) only the
 * low-order octets contribute to the result.
 *
 * @param[in] ptr
 *	octets, most significant first
 *
 * @param[in] n
 *	number of octets at ptr
 *
 * @return the decoded value, 0 if ptr is NULL or n is 0.
 */
int
octets2i(unsigned char *ptr, size_t n)
{
	unsigned int val;
	size_t i;

	if (ptr == NULL)
		return (0);

	val = 0;
	for (i = 0; i < n; i++) {
		val = (val << 8) | ptr[i];
	}
	return ((int)val);
}
74
75/**
76 * @brief decode packet tag
77 *
78 * Also indicate if new/old and in the later case
79 * the length type
80 *
81 * @sa rfc4880:4.2
82 */
83int
84decode_tag(unsigned char *ptr, int *isnew, int *ltype)
85{
86	int tag;
87
88	if (!ptr || !isnew || !ltype)
89		return (-1);
90	tag = *ptr;
91
92	if (!(tag & OPENPGP_TAG_ISTAG))
93		return (-1);		/* we are lost! */
94	*isnew = tag & OPENPGP_TAG_ISNEW;
95	if (*isnew) {
96		*ltype = -1;		/* irrelevant */
97		tag &= OPENPGP_TAG_NEW_MASK;
98	} else {
99		*ltype = tag & OPENPGP_TAG_OLD_TYPE;
100		tag = (tag & OPENPGP_TAG_OLD_MASK) >> 2;
101	}
102	return (tag);
103}
104
/**
 * @brief return new format packet length
 *
 * Handles the one-octet, two-octet and five-octet length forms.
 * Partial body lengths (first octet 224..254) are not supported.
 *
 * @param[in,out] pptr
 *	points at the first length octet; on success it is advanced
 *	past all the length octets, on failure it is left untouched
 *
 * @return the length, or -1 if pptr or *pptr is NULL, the form is
 * not supported, or the length does not fit in an int.
 *
 * @sa rfc4880:4.2.2
 */
static int
decode_new_len(unsigned char **pptr)
{
	unsigned char *ptr;
	unsigned int ulen;
	int len;

	if (pptr == NULL || *pptr == NULL)
		return (-1);
	ptr = *pptr;

	if (*ptr < 192) {
		/* one-octet length */
		len = ptr[0];
		ptr += 1;
	} else if (*ptr < 224) {
		/* two-octet length, both octets are consumed */
		len = ((ptr[0] - 192) << 8) + ptr[1] + 192;
		ptr += 2;
	} else if (*ptr == 255) {
		/*
		 * five-octet length: the 255 is only a marker,
		 * the length is the four-octet scalar that follows.
		 */
		if (ptr[1] & 0x80)
			return (-1);	/* would not fit in an int */
		ulen = ((unsigned int)ptr[1] << 24) |
		    ((unsigned int)ptr[2] << 16) |
		    ((unsigned int)ptr[3] << 8) |
		    (unsigned int)ptr[4];
		len = (int)ulen;
		ptr += 5;
	} else {
		return (-1);		/* not supported */
	}

	*pptr = ptr;
	return (len);
}
138
/**
 * @brief return packet length
 *
 * A negative ltype selects the new format length encoding,
 * which is handled by decode_new_len().  Otherwise ltype is an
 * old format length type: 0, 1 and 2 are one, two and four
 * octet big-endian lengths.  Type 3 is not supported.
 *
 * @param[in,out] pptr
 *	points at the first length octet; on success it is advanced
 *	past the length octets, on failure it is left untouched
 *
 * @param[in] ltype
 *	old format length type, or negative for new format
 *
 * @return the length, or -1 if pptr or *pptr is NULL, the length
 * type is not supported, or the length does not fit in an int.
 *
 * @sa rfc4880:4.2.1
 */
static int
decode_len(unsigned char **pptr, int ltype)
{
	unsigned char *ptr;
	int len;

	if (ltype < 0)
		return (decode_new_len(pptr));

	if (pptr == NULL || *pptr == NULL)
		return (-1);

	ptr = *pptr;

	switch (ltype) {
	case 0:
		len = ptr[0];
		ptr += 1;
		break;
	case 1:
		len = (ptr[0] << 8) | ptr[1];
		ptr += 2;
		break;
	case 2:
		/*
		 * Build the value unsigned: shifting an octet with
		 * the top bit set by 24 overflows a signed int.
		 */
		if (ptr[0] & 0x80)
			return (-1);	/* would not fit in an int */
		len = (int)(((unsigned int)ptr[0] << 24) |
		    ((unsigned int)ptr[1] << 16) |
		    ((unsigned int)ptr[2] << 8) |
		    (unsigned int)ptr[3]);
		ptr += 4;
		break;
	case 3:
	default:
		/* Not supported */
		return (-1);
	}

	*pptr = ptr;
	return (len);
}
181
/**
 * @brief return pointer and length of an mpi
 *
 * An mpi starts with a two-octet big-endian count of bits,
 * followed by that many bits rounded up to whole octets.
 *
 * @param[in,out] pptr
 *	points at the mpi, advanced to just past it on return
 *
 * @param[out] sz
 *	number of data octets in the mpi
 *
 * @return pointer to the first data octet,
 * NULL if pptr or sz is NULL.
 *
 * @sa rfc4880:3.2
 */
unsigned char *
decode_mpi(unsigned char **pptr, size_t *sz)
{
	unsigned char *cursor;
	size_t nbits;
	size_t noctets;

	if (pptr == NULL || sz == NULL)
		return (NULL);

	cursor = *pptr;
	nbits = ((size_t)cursor[0] << 8) | (size_t)cursor[1];
	noctets = (nbits + 7) / 8;
	cursor += 2;			/* skip the bit count */

	*sz = noctets;
	*pptr = cursor + noctets;
	return (cursor);
}
208
209/**
210 * @brief return an OpenSSL BIGNUM from mpi
211 *
212 * @sa rfc4880:3.2
213 */
214#ifdef USE_BEARSSL
215unsigned char *
216mpi2bn(unsigned char **pptr, size_t *sz)
217{
218	return (decode_mpi(pptr, sz));
219}
220#else
221BIGNUM *
222mpi2bn(unsigned char **pptr)
223{
224	BIGNUM *bn = NULL;
225	unsigned char *ptr;
226	int mlen;
227
228	if (pptr == NULL)
229		return (NULL);
230
231	ptr = *pptr;
232
233	mlen = (*ptr++ << 8);
234	mlen |= *ptr++;			/* number of bits */
235	mlen = (mlen + 7) / 8;		/* number of bytes */
236	bn = BN_bin2bn(ptr, mlen, NULL);
237	ptr += mlen;
238	*pptr = ptr;
239
240	return (bn);
241}
242#endif
243
244/**
245 * @brief decode a packet
246 *
247 * If want is set, check that the packet tag matches
248 * if all good, call the provided decoder with its arg
249 *
250 * @return count of unconsumed data
251 *
252 * @sa rfc4880:4.2
253 */
254int
255decode_packet(int want, unsigned char **pptr, size_t nbytes,
256    decoder_t decoder, void *decoder_arg)
257{
258	int tag;
259	unsigned char *ptr;
260	unsigned char *nptr;
261	int isnew, ltype;
262	int len;
263	int hlen;
264	int rc = 0;
265
266	nptr = ptr = *pptr;
267
268	tag = decode_tag(ptr, &isnew, &ltype);
269
270	if (want > 0 && tag != want)
271		return (-1);
272	ptr++;
273
274	len = rc = decode_len(&ptr, ltype);
275	hlen = (int)(ptr - nptr);
276	nptr = ptr + len;		/* consume it */
277
278	if (decoder)
279		rc = decoder(tag, &ptr, len, decoder_arg);
280	*pptr = nptr;
281	nbytes -= (size_t)(hlen + len);
282	if (rc < 0)
283		return (rc);		/* error */
284	return ((int)nbytes);		/* unconsumed data */
285}
286
/**
 * @brief decode a sub packet
 *
 * A sub packet is a new format length, followed by that many
 * octets of which the first is the sub packet type.
 *
 * @param[in,out] pptr
 *	points at the sub packet, advanced past it on return
 *	unless the length could not be decoded
 *
 * @param[out] stag
 *	the sub packet type, or -1 if there is none
 *
 * @param[out] sz
 *	total octets consumed (length octets plus body),
 *	or -1 if the length could not be decoded
 *
 * @return pointer to the octet following the type octet.
 * NULL if any argument is NULL, if the length could not be
 * decoded (*pptr is left untouched), or if the length is zero
 * so that there is no type octet (the length octets are still
 * consumed so that callers make progress).
 *
 * @sa rfc4880:5.2.3.1
 */
unsigned char *
decode_subpacket(unsigned char **pptr, int *stag, int *sz)
{
	unsigned char *ptr;
	int len;

	if (pptr == NULL || *pptr == NULL || stag == NULL || sz == NULL)
		return (NULL);

	ptr = *pptr;
	len = decode_len(&ptr, -1);
	if (len < 0) {
		/* do not walk backwards on a bad length */
		*stag = -1;
		*sz = -1;
		return (NULL);
	}

	*sz = (int)(len + (ptr - *pptr));
	*pptr = ptr + len;
	if (len == 0) {
		/* empty body: there is no type octet to read */
		*stag = -1;
		return (NULL);
	}
	*stag = *ptr++;
	return (ptr);
}
305