1// SPDX-License-Identifier: GPL-2.0+
2/*
3 *  charset conversion utils
4 *
5 *  Copyright (c) 2017 Rob Clark
6 */
7
8#include <charset.h>
9#include <capitalization.h>
10#include <cp437.h>
11#include <efi_loader.h>
12#include <errno.h>
13#include <malloc.h>
14
15/**
16 * codepage_437 - Unicode to codepage 437 translation table
17 */
18const u16 codepage_437[160] = CP437;
19
20static struct capitalization_table capitalization_table[] =
21#ifdef CONFIG_EFI_UNICODE_CAPITALIZATION
22	UNICODE_CAPITALIZATION_TABLE;
23#elif CONFIG_FAT_DEFAULT_CODEPAGE == 1250
24	CP1250_CAPITALIZATION_TABLE;
25#else
26	CP437_CAPITALIZATION_TABLE;
27#endif
28
29/**
30 * get_code() - read Unicode code point from UTF-8 stream
31 *
32 * @read_u8:	- stream reader
33 * @src:	- string buffer passed to stream reader, optional
34 * Return:	- Unicode code point, or -1
35 */
36static int get_code(u8 (*read_u8)(void *data), void *data)
37{
38	s32 ch = 0;
39
40	ch = read_u8(data);
41	if (!ch)
42		return 0;
43	if (ch >= 0xc2 && ch <= 0xf4) {
44		int code = 0;
45
46		if (ch >= 0xe0) {
47			if (ch >= 0xf0) {
48				/* 0xf0 - 0xf4 */
49				ch &= 0x07;
50				code = ch << 18;
51				ch = read_u8(data);
52				if (ch < 0x80 || ch > 0xbf)
53					goto error;
54				ch &= 0x3f;
55			} else {
56				/* 0xe0 - 0xef */
57				ch &= 0x0f;
58			}
59			code += ch << 12;
60			if ((code >= 0xD800 && code <= 0xDFFF) ||
61			    code >= 0x110000)
62				goto error;
63			ch = read_u8(data);
64			if (ch < 0x80 || ch > 0xbf)
65				goto error;
66		}
67		/* 0xc0 - 0xdf or continuation byte (0x80 - 0xbf) */
68		ch &= 0x3f;
69		code += ch << 6;
70		ch = read_u8(data);
71		if (ch < 0x80 || ch > 0xbf)
72			goto error;
73		ch &= 0x3f;
74		ch += code;
75	} else if (ch >= 0x80) {
76		goto error;
77	}
78	return ch;
79error:
80	return -1;
81}
82
83/**
84 * read_string() - read byte from character string
85 *
86 * @data:	- pointer to string
87 * Return:	- byte read
88 *
89 * The string pointer is incremented if it does not point to '\0'.
90 */
91static u8 read_string(void *data)
92
93{
94	const char **src = (const char **)data;
95	u8 c;
96
97	if (!src || !*src || !**src)
98		return 0;
99	c = **src;
100	++*src;
101	return c;
102}
103
104/**
105 * read_console() - read byte from console
106 *
107 * @data	- not used, needed to match interface
108 * Return:	- byte read or 0 on error
109 */
110static u8 read_console(void *data)
111{
112	int ch;
113
114	ch = getchar();
115	if (ch < 0)
116		ch = 0;
117	return ch;
118}
119
120int console_read_unicode(s32 *code)
121{
122	for (;;) {
123		s32 c;
124
125		if (!tstc()) {
126			/* No input available */
127			return 1;
128		}
129
130		/* Read Unicode code */
131		c = get_code(read_console, NULL);
132		if (c > 0) {
133			*code = c;
134			return 0;
135		}
136	}
137}
138
139s32 utf8_get(const char **src)
140{
141	return get_code(read_string, src);
142}
143
144int utf8_put(s32 code, char **dst)
145{
146	if (!dst || !*dst)
147		return -1;
148	if ((code >= 0xD800 && code <= 0xDFFF) || code >= 0x110000)
149		return -1;
150	if (code <= 0x007F) {
151		**dst = code;
152	} else {
153		if (code <= 0x07FF) {
154			**dst = code >> 6 | 0xC0;
155		} else {
156			if (code < 0x10000) {
157				**dst = code >> 12 | 0xE0;
158			} else {
159				**dst = code >> 18 | 0xF0;
160				++*dst;
161				**dst = (code >> 12 & 0x3F) | 0x80;
162			}
163			++*dst;
164			**dst = (code >> 6 & 0x3F) | 0x80;
165		}
166		++*dst;
167		**dst = (code & 0x3F) | 0x80;
168	}
169	++*dst;
170	return 0;
171}
172
173size_t utf8_utf16_strnlen(const char *src, size_t count)
174{
175	size_t len = 0;
176
177	for (; *src && count; --count)  {
178		s32 code = utf8_get(&src);
179
180		if (!code)
181			break;
182		if (code < 0) {
183			/* Reserve space for a replacement character */
184			len += 1;
185		} else if (code < 0x10000) {
186			len += 1;
187		} else {
188			len += 2;
189		}
190	}
191	return len;
192}
193
194int utf8_utf16_strncpy(u16 **dst, const char *src, size_t count)
195{
196	if (!src || !dst || !*dst)
197		return -1;
198
199	for (; count && *src; --count) {
200		s32 code = utf8_get(&src);
201
202		if (code < 0)
203			code = '?';
204		utf16_put(code, dst);
205	}
206	**dst = 0;
207	return 0;
208}
209
210s32 utf16_get(const u16 **src)
211{
212	s32 code, code2;
213
214	if (!src || !*src)
215		return -1;
216	if (!**src)
217		return 0;
218	code = **src;
219	++*src;
220	if (code >= 0xDC00 && code <= 0xDFFF)
221		return -1;
222	if (code >= 0xD800 && code <= 0xDBFF) {
223		if (!**src)
224			return -1;
225		code &= 0x3ff;
226		code <<= 10;
227		code += 0x10000;
228		code2 = **src;
229		++*src;
230		if (code2 <= 0xDC00 || code2 >= 0xDFFF)
231			return -1;
232		code2 &= 0x3ff;
233		code += code2;
234	}
235	return code;
236}
237
238int utf16_put(s32 code, u16 **dst)
239{
240	if (!dst || !*dst)
241		return -1;
242	if ((code >= 0xD800 && code <= 0xDFFF) || code >= 0x110000)
243		return -1;
244	if (code < 0x10000) {
245		**dst = code;
246	} else {
247		code -= 0x10000;
248		**dst = code >> 10 | 0xD800;
249		++*dst;
250		**dst = (code & 0x3ff) | 0xDC00;
251	}
252	++*dst;
253	return 0;
254}
255
256size_t utf16_strnlen(const u16 *src, size_t count)
257{
258	size_t len = 0;
259
260	for (; *src && count; --count)  {
261		s32 code = utf16_get(&src);
262
263		if (!code)
264			break;
265		/*
266		 * In case of an illegal sequence still reserve space for a
267		 * replacement character.
268		 */
269		++len;
270	}
271	return len;
272}
273
274size_t utf16_utf8_strnlen(const u16 *src, size_t count)
275{
276	size_t len = 0;
277
278	for (; *src && count; --count)  {
279		s32 code = utf16_get(&src);
280
281		if (!code)
282			break;
283		if (code < 0)
284			/* Reserve space for a replacement character */
285			len += 1;
286		else if (code < 0x80)
287			len += 1;
288		else if (code < 0x800)
289			len += 2;
290		else if (code < 0x10000)
291			len += 3;
292		else
293			len += 4;
294	}
295	return len;
296}
297
298int utf16_utf8_strncpy(char **dst, const u16 *src, size_t count)
299{
300	if (!src || !dst || !*dst)
301		return -1;
302
303	for (; count && *src; --count) {
304		s32 code = utf16_get(&src);
305
306		if (code < 0)
307			code = '?';
308		utf8_put(code, dst);
309	}
310	**dst = 0;
311	return 0;
312}
313
314s32 utf_to_lower(const s32 code)
315{
316	struct capitalization_table *pos = capitalization_table;
317	s32 ret = code;
318
319	if (code <= 0x7f) {
320		if (code >= 'A' && code <= 'Z')
321			ret += 0x20;
322		return ret;
323	}
324	for (; pos->upper; ++pos) {
325		if (pos->upper == code) {
326			ret = pos->lower;
327			break;
328		}
329	}
330	return ret;
331}
332
333s32 utf_to_upper(const s32 code)
334{
335	struct capitalization_table *pos = capitalization_table;
336	s32 ret = code;
337
338	if (code <= 0x7f) {
339		if (code >= 'a' && code <= 'z')
340			ret -= 0x20;
341		return ret;
342	}
343	for (; pos->lower; ++pos) {
344		if (pos->lower == code) {
345			ret = pos->upper;
346			break;
347		}
348	}
349	return ret;
350}
351
352/*
353 * u16_strcasecmp() - compare two u16 strings case insensitively
354 *
355 * @s1:		first string to compare
356 * @s2:		second string to compare
357 * @n:		maximum number of u16 to compare
358 * Return:	0  if the first n u16 are the same in s1 and s2
359 *		< 0 if the first different u16 in s1 is less than the
360 *		corresponding u16 in s2
361 *		> 0 if the first different u16 in s1 is greater than the
362 */
363int u16_strcasecmp(const u16 *s1, const u16 *s2)
364{
365	int ret = 0;
366	s32 c1, c2;
367
368	for (;;) {
369		c1 = utf_to_upper(utf16_get(&s1));
370		c2 = utf_to_upper(utf16_get(&s2));
371		ret = c1 - c2;
372		if (ret || !c1 || c1 == -1 || c2 == -1)
373			break;
374	}
375	return ret;
376}
377
378/*
379 * u16_strncmp() - compare two u16 string
380 *
381 * @s1:		first string to compare
382 * @s2:		second string to compare
383 * @n:		maximum number of u16 to compare
384 * Return:	0  if the first n u16 are the same in s1 and s2
385 *		< 0 if the first different u16 in s1 is less than the
386 *		corresponding u16 in s2
387 *		> 0 if the first different u16 in s1 is greater than the
388 *		corresponding u16 in s2
389 */
390int __efi_runtime u16_strncmp(const u16 *s1, const u16 *s2, size_t n)
391{
392	int ret = 0;
393
394	for (; n; --n, ++s1, ++s2) {
395		ret = *s1 - *s2;
396		if (ret || !*s1)
397			break;
398	}
399
400	return ret;
401}
402
403size_t __efi_runtime u16_strnlen(const u16 *in, size_t count)
404{
405	size_t i;
406	for (i = 0; count-- && in[i]; i++);
407	return i;
408}
409
410size_t u16_strsize(const void *in)
411{
412	return (u16_strlen(in) + 1) * sizeof(u16);
413}
414
415u16 *u16_strcpy(u16 *dest, const u16 *src)
416{
417	u16 *tmp = dest;
418
419	for (;; dest++, src++) {
420		*dest = *src;
421		if (!*src)
422			break;
423	}
424
425	return tmp;
426}
427
428u16 *u16_strdup(const void *src)
429{
430	u16 *new;
431	size_t len;
432
433	if (!src)
434		return NULL;
435	len = u16_strsize(src);
436	new = malloc(len);
437	if (!new)
438		return NULL;
439	memcpy(new, src, len);
440
441	return new;
442}
443
444size_t u16_strlcat(u16 *dest, const u16 *src, size_t count)
445{
446	size_t destlen = u16_strnlen(dest, count);
447	size_t srclen = u16_strlen(src);
448	size_t ret = destlen + srclen;
449
450	if (destlen >= count)
451		return ret;
452	if (ret >= count)
453		srclen -= (ret - count + 1);
454	memcpy(&dest[destlen], src, 2 * srclen);
455	dest[destlen + srclen] = 0x0000;
456
457	return ret;
458}
459
/*
 * utf16_to_utf8() - convert a UTF-16 buffer to UTF-8
 *
 * @dest:	output buffer
 * @src:	UTF-16 input
 * @size:	number of 16 bit units to convert
 * Return:	pointer to the byte following the last byte written
 *
 * Exactly @size units are read; a zero unit does not end the conversion and
 * no terminating zero is written. Every surrogate that is not part of a
 * valid pair is replaced by '?', including a high surrogate in the last unit.
 * At most three bytes are written per input unit.
 */
uint8_t *utf16_to_utf8(uint8_t *dest, const uint16_t *src, size_t size)
{
	uint32_t code_high = 0;

	while (size--) {
		uint32_t code = *src++;

		if (code_high) {
			if (code >= 0xDC00 && code <= 0xDFFF) {
				/* Surrogate pair */
				code = ((code_high - 0xD800) << 10) +
				       (code - 0xDC00) + 0x10000;

				*dest++ = (code >> 18) | 0xF0;
				*dest++ = ((code >> 12) & 0x3F) | 0x80;
				*dest++ = ((code >> 6) & 0x3F) | 0x80;
				*dest++ = (code & 0x3F) | 0x80;
				code_high = 0;
				continue;
			}
			/*
			 * Unpaired high surrogate. The current unit may be
			 * valid on its own, so it is not skipped but handled
			 * below like any other unit.
			 */
			*dest++ = '?';
			code_high = 0;
		}

		if (code <= 0x007F) {
			*dest++ = code;
		} else if (code <= 0x07FF) {
			*dest++ = (code >> 6) | 0xC0;
			*dest++ = (code & 0x3F) | 0x80;
		} else if (code >= 0xD800 && code <= 0xDBFF) {
			/* Wait for the low surrogate */
			code_high = code;
		} else if (code >= 0xDC00 && code <= 0xDFFF) {
			/* Low surrogate without preceding high surrogate */
			*dest++ = '?';
		} else {
			*dest++ = (code >> 12) | 0xE0;
			*dest++ = ((code >> 6) & 0x3F) | 0x80;
			*dest++ = (code & 0x3F) | 0x80;
		}
	}

	/* The input ended in the middle of a surrogate pair */
	if (code_high)
		*dest++ = '?';

	return dest;
}
512
513int utf_to_cp(s32 *c, const u16 *codepage)
514{
515	if (*c >= 0x80) {
516		int j;
517
518		/* Look up codepage translation */
519		for (j = 0; j < 0xA0; ++j) {
520			if (*c == codepage[j]) {
521				if (j < 0x20)
522					*c = j;
523				else
524					*c = j + 0x60;
525				return 0;
526			}
527		}
528		*c = '?';
529		return -ENOENT;
530	}
531	return 0;
532}
533
534int utf8_to_cp437_stream(u8 c, char *buffer)
535{
536	char *end;
537	const char *pos;
538	s32 s;
539	int ret;
540
541	for (;;) {
542		pos = buffer;
543		end = buffer + strlen(buffer);
544		*end++ = c;
545		*end = 0;
546		s = utf8_get(&pos);
547		if (s > 0) {
548			*buffer = 0;
549			ret = utf_to_cp(&s, codepage_437);
550			return s;
551			}
552		if (pos == end)
553			return 0;
554		*buffer = 0;
555	}
556}
557
558int utf8_to_utf32_stream(u8 c, char *buffer)
559{
560	char *end;
561	const char *pos;
562	s32 s;
563
564	for (;;) {
565		pos = buffer;
566		end = buffer + strlen(buffer);
567		*end++ = c;
568		*end = 0;
569		s = utf8_get(&pos);
570		if (s > 0) {
571			*buffer = 0;
572			return s;
573		}
574		if (pos == end)
575			return 0;
576		/*
577		 * Appending the byte lead to an invalid UTF-8 byte sequence.
578		 * Consider it as the start of a new code sequence.
579		 */
580		*buffer = 0;
581	}
582}
583