1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * String functions optimized for hardware which doesn't
4 * handle unaligned memory accesses efficiently.
5 *
6 * Copyright (C) 2021 Matteo Croce
7 */
8
9#include <linux/types.h>
10#include <linux/module.h>
11
/*
 * BYTES_LONG is the width of one bulk-copy word, WORD_MASK selects the
 * address bits that make a pointer misaligned to such a word, and
 * MIN_THRESHOLD is the minimum length (two words) for which the word-wise
 * code paths are considered worth their setup cost.
 */
#define BYTES_LONG	(sizeof(unsigned long))
#define WORD_MASK	(BYTES_LONG - 1)
#define MIN_THRESHOLD	(2 * BYTES_LONG)
16
/*
 * Convenience union that lets a single destination pointer be viewed as a
 * byte pointer, a word pointer, or a plain integer, without casting between
 * pointer types.
 */
union types {
	u8 *as_u8;			/* byte-at-a-time access */
	unsigned long *as_ulong;	/* word-at-a-time access */
	uintptr_t as_uptr;		/* integer view, used for alignment tests */
};
23
24union const_types {
25	const u8 *as_u8;
26	unsigned long *as_ulong;
27	uintptr_t as_uptr;
28};
29
30void *memcpy(void *dest, const void *src, size_t count)
31{
32	union const_types s = { .as_u8 = src };
33	union types d = { .as_u8 = dest };
34	int distance = 0;
35
36	if (count < MIN_THRESHOLD)
37		goto copy_remainder;
38
39	/* Copy a byte at time until destination is aligned. */
40	for (; d.as_uptr & WORD_MASK; count--)
41		*d.as_u8++ = *s.as_u8++;
42
43	distance = s.as_uptr & WORD_MASK;
44
45	if (distance) {
46		unsigned long last, next;
47
48		/*
49		 * s is distance bytes ahead of d, and d just reached
50		 * the alignment boundary. Move s backward to word align it
51		 * and shift data to compensate for distance, in order to do
52		 * word-by-word copy.
53		 */
54		s.as_u8 -= distance;
55
56		next = s.as_ulong[0];
57		for (; count >= BYTES_LONG; count -= BYTES_LONG) {
58			last = next;
59			next = s.as_ulong[1];
60
61			d.as_ulong[0] = last >> (distance * 8) |
62				next << ((BYTES_LONG - distance) * 8);
63
64			d.as_ulong++;
65			s.as_ulong++;
66		}
67
68		/* Restore s with the original offset. */
69		s.as_u8 += distance;
70	} else {
71		/*
72		 * If the source and dest lower bits are the same, do a simple
73		 * 32/64 bit wide copy.
74		 */
75		for (; count >= BYTES_LONG; count -= BYTES_LONG)
76			*d.as_ulong++ = *s.as_ulong++;
77	}
78
79copy_remainder:
80	while (count--)
81		*d.as_u8++ = *s.as_u8++;
82
83	return dest;
84}
85EXPORT_SYMBOL(memcpy);
86
87/*
88 * Simply check if the buffer overlaps an call memcpy() in case,
89 * otherwise do a simple one byte at time backward copy.
90 */
91void *memmove(void *dest, const void *src, size_t count)
92{
93	if (dest < src || src + count <= dest)
94		return memcpy(dest, src, count);
95
96	if (dest > src) {
97		const char *s = src + count;
98		char *tmp = dest + count;
99
100		while (count--)
101			*--tmp = *--s;
102	}
103	return dest;
104}
/* Make memmove() available to loadable modules. */
EXPORT_SYMBOL(memmove);
106
107void *memset(void *s, int c, size_t count)
108{
109	union types dest = { .as_u8 = s };
110
111	if (count >= MIN_THRESHOLD) {
112		unsigned long cu = (unsigned long)c;
113
114		/* Compose an ulong with 'c' repeated 4/8 times */
115		cu |= cu << 8;
116		cu |= cu << 16;
117		/* Suppress warning on 32 bit machines */
118		cu |= (cu << 16) << 16;
119
120		for (; count && dest.as_uptr & WORD_MASK; count--)
121			*dest.as_u8++ = c;
122
123		/* Copy using the largest size allowed */
124		for (; count >= BYTES_LONG; count -= BYTES_LONG)
125			*dest.as_ulong++ = cu;
126	}
127
128	/* copy the remainder */
129	while (count--)
130		*dest.as_u8++ = c;
131
132	return s;
133}
134EXPORT_SYMBOL(memset);
135