// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright 2023 Red Hat
 */

#include "radix-sort.h"

#include <linux/limits.h>
#include <linux/types.h>

#include "memory-alloc.h"
#include "string-utils.h"

/*
 * This implementation allocates one large object to do the sorting, which can be reused as many
 * times as desired. The amount of memory required is linearly proportional to the number of keys
 * to be sorted.
 */
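
/*
 * For illustration only, a minimal usage sketch of the public entry points defined below
 * (uds_make_radix_sorter(), uds_radix_sort(), and uds_free_radix_sorter()); the key count and
 * key length here are hypothetical and error handling is abbreviated:
 *
 *	const u8 *keys[1024];		// filled in by the caller; each points to a 16-byte key
 *	struct radix_sorter *sorter;
 *	int result;
 *
 *	result = uds_make_radix_sorter(ARRAY_SIZE(keys), &sorter);
 *	if (result != UDS_SUCCESS)
 *		return result;
 *
 *	result = uds_radix_sort(sorter, keys, ARRAY_SIZE(keys), 16);
 *	uds_free_radix_sorter(sorter);
 *	return result;
 */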

/* Piles smaller than this are handled with a simple insertion sort. */
#define INSERTION_SORT_THRESHOLD 12

/* Sort keys are pointers to immutable fixed-length arrays of bytes. */
typedef const u8 *sort_key_t;

/*
 * The keys are separated into piles based on the byte in each key at the current offset, so the
 * number of keys with each byte must be counted.
 */
struct histogram {
	/* The number of non-empty bins */
	u16 used;
	/* The index (key byte) of the first non-empty bin */
	u16 first;
	/* The index (key byte) of the last non-empty bin */
	u16 last;
	/* The number of occurrences of each specific byte */
	u32 size[256];
};

/*
 * Sub-tasks are manually managed on a stack, both for performance and to keep the stack space
 * needed within a fixed bound that can be allocated up front.
 */
struct task {
	/* Pointer to the first key to sort. */
	sort_key_t *first_key;
	/* Pointer to the last key to sort. */
	sort_key_t *last_key;
	/* The offset into the key at which to continue sorting. */
	u16 offset;
	/* The number of bytes remaining in the sort keys. */
	u16 length;
};

struct radix_sorter {
	unsigned int count;
	struct histogram bins;
	sort_key_t *pile[256];
	struct task *end_of_stack;
	struct task insertion_list[256];
	struct task stack[];
};
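
/*
 * Note: the flexible stack[] array never needs more than count / INSERTION_SORT_THRESHOLD
 * entries, because every task pushed onto it covers more than INSERTION_SORT_THRESHOLD keys
 * and the tasks on the stack always cover disjoint ranges of the key array;
 * uds_make_radix_sorter() sizes the allocation using this bound.
 */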

/* Compare a segment of two fixed-length keys starting at an offset. */
static inline int compare(sort_key_t key1, sort_key_t key2, u16 offset, u16 length)
{
	return memcmp(&key1[offset], &key2[offset], length);
}

/* Insert the next unsorted key into an array of sorted keys. */
static inline void insert_key(const struct task task, sort_key_t *next)
{
	/* Pull the unsorted key out, freeing up the array slot. */
	sort_key_t unsorted = *next;

	/* Compare the key to the preceding sorted entries, shifting down ones that are larger. */
	while ((--next >= task.first_key) &&
	       (compare(unsorted, next[0], task.offset, task.length) < 0))
		next[1] = next[0];

	/* Insert the key into the last slot that was cleared, sorting it. */
	next[1] = unsorted;
}

/*
 * Sort a range of key segments using an insertion sort. This simple sort is faster than the
 * 256-way radix sort when the number of keys to sort is small.
 */
static inline void insertion_sort(const struct task task)
{
	sort_key_t *next;

	for (next = task.first_key + 1; next <= task.last_key; next++)
		insert_key(task, next);
}

/* Push a sorting task onto a task stack. */
static inline void push_task(struct task **stack_pointer, sort_key_t *first_key,
			     u32 count, u16 offset, u16 length)
{
	struct task *task = (*stack_pointer)++;

	task->first_key = first_key;
	task->last_key = &first_key[count - 1];
	task->offset = offset;
	task->length = length;
}

static inline void swap_keys(sort_key_t *a, sort_key_t *b)
{
	sort_key_t c = *a;
	*a = *b;
	*b = c;
}

/*
 * Count the number of times each byte value appears in the array of keys to sort at the current
 * offset, keeping track of the number of non-empty bins, and the index of the first and last
 * non-empty bin.
 */
static inline void measure_bins(const struct task task, struct histogram *bins)
{
	sort_key_t *key_ptr;

	/*
	 * Subtle invariant: bins->used and bins->size[] are zero because the sorting code clears
	 * it all out as it goes. Even though this structure is re-used, we don't need to pay to
	 * zero it before starting a new tally.
	 */
	bins->first = U8_MAX;
	bins->last = 0;

	for (key_ptr = task.first_key; key_ptr <= task.last_key; key_ptr++) {
		/* Increment the count for the byte in the key at the current offset. */
		u8 bin = (*key_ptr)[task.offset];
		u32 size = ++bins->size[bin];

		/* Track non-empty bins. */
		if (size == 1) {
			bins->used += 1;
			if (bin < bins->first)
				bins->first = bin;

			if (bin > bins->last)
				bins->last = bin;
		}
	}
}

/*
 * Convert the bin sizes to pointers to where each pile goes.
 *
 *   pile[0] = first_key + bins->size[0],
 *   pile[1] = pile[0] + bins->size[1], etc.
 *
 * After the keys are moved to the appropriate pile, we'll need to sort each of the piles by the
 * next radix position. A new task is put on the stack for each pile containing many keys, or on
 * the insertion-sort list for each pile containing only a few keys.
 *
 * @stack: pointer to the top of the stack
 * @end_of_stack: the end of the stack
 * @list: pointer to the head of the insertion-sort list
 * @pile: array for pointers to the end of each pile
 * @bins: the histogram of the sizes of each pile
 * @first_key: the first key of the range being subdivided
 * @offset: the next radix position to sort by
 * @length: the number of bytes remaining in the sort keys
 *
 * Return: UDS_SUCCESS or an error code
 */
static inline int push_bins(struct task **stack, struct task *end_of_stack,
			    struct task **list, sort_key_t *pile[],
			    struct histogram *bins, sort_key_t *first_key,
			    u16 offset, u16 length)
{
	sort_key_t *pile_start = first_key;
	int bin;

	for (bin = bins->first; ; bin++) {
		u32 size = bins->size[bin];

		/* Skip empty piles. */
		if (size == 0)
			continue;

		/* There's no need to sort empty keys. */
		if (length > 0) {
			if (size > INSERTION_SORT_THRESHOLD) {
				if (*stack >= end_of_stack)
					return UDS_BAD_STATE;

				push_task(stack, pile_start, size, offset, length);
			} else if (size > 1) {
				push_task(list, pile_start, size, offset, length);
			}
		}

		pile_start += size;
		pile[bin] = pile_start;
		if (--bins->used == 0)
			break;
	}

	return UDS_SUCCESS;
}

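/*
 * For illustration, with hypothetical sizes: if push_bins() is given 100 keys whose bytes at
 * the current offset produced bins->size['a'] = 30, bins->size['b'] = 50, and
 * bins->size['c'] = 20, it sets pile['a'] = first_key + 30, pile['b'] = first_key + 80, and
 * pile['c'] = first_key + 100. Each pile pointer is one past the end of its pile, so the
 * distribution loop in uds_radix_sort() can fill each pile from its end downward.
 */
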
int uds_make_radix_sorter(unsigned int count, struct radix_sorter **sorter)
{
	int result;
	unsigned int stack_size = count / INSERTION_SORT_THRESHOLD;
	struct radix_sorter *radix_sorter;

	result = vdo_allocate_extended(struct radix_sorter, stack_size, struct task,
				       __func__, &radix_sorter);
	if (result != VDO_SUCCESS)
		return result;

	radix_sorter->count = count;
	radix_sorter->end_of_stack = radix_sorter->stack + stack_size;
	*sorter = radix_sorter;
	return UDS_SUCCESS;
}

void uds_free_radix_sorter(struct radix_sorter *sorter)
{
	vdo_free(sorter);
}

/*
 * Sort pointers to fixed-length keys (arrays of bytes) using a radix sort. The sort implementation
 * is unstable, so the relative ordering of equal keys is not preserved.
 */
int uds_radix_sort(struct radix_sorter *sorter, const unsigned char *keys[],
		   unsigned int count, unsigned short length)
{
	struct task start;
	struct histogram *bins = &sorter->bins;
	sort_key_t **pile = sorter->pile;
	struct task *task_stack = sorter->stack;

	/* All zero-length keys are identical and therefore already sorted. */
	if ((count == 0) || (length == 0))
		return UDS_SUCCESS;

	/* The initial task is to sort the entire length of all the keys. */
	start = (struct task) {
		.first_key = keys,
		.last_key = &keys[count - 1],
		.offset = 0,
		.length = length,
	};

	if (count <= INSERTION_SORT_THRESHOLD) {
		insertion_sort(start);
		return UDS_SUCCESS;
	}

	if (count > sorter->count)
		return UDS_INVALID_ARGUMENT;

	/*
	 * Repeatedly consume a sorting task from the stack and process it, pushing new sub-tasks
	 * onto the stack for each radix-sorted pile. When all tasks and sub-tasks have been
	 * processed, the stack will be empty and all the keys in the starting task will be fully
	 * sorted.
	 */
	for (*task_stack = start; task_stack >= sorter->stack; task_stack--) {
		const struct task task = *task_stack;
		struct task *insertion_task_list;
		int result;
		sort_key_t *fence;
		sort_key_t *end;

		measure_bins(task, bins);

		/*
		 * Now that we know how large each bin is, generate pointers for each of the piles
		 * and push a new task to sort each pile by the next radix byte.
		 */
		insertion_task_list = sorter->insertion_list;
		result = push_bins(&task_stack, sorter->end_of_stack,
				   &insertion_task_list, pile, bins, task.first_key,
				   task.offset + 1, task.length - 1);
		if (result != UDS_SUCCESS) {
			memset(bins, 0, sizeof(*bins));
			return result;
		}

		/* Now bins->used is zero again. */

		/*
		 * Don't bother processing the last pile: when piles 0..N-1 are all in place, then
		 * pile N must also be in place.
		 */
		end = task.last_key - bins->size[bins->last];
		bins->size[bins->last] = 0;

		for (fence = task.first_key; fence <= end; ) {
			u8 bin;
			sort_key_t key = *fence;

			/*
			 * The radix byte of the key tells us which pile it belongs in. Swap it for
			 * an unprocessed item just below that pile, and repeat.
			 */
			while (--pile[bin = key[task.offset]] > fence)
				swap_keys(pile[bin], &key);

			/*
			 * The pile reached the fence. Put the key at the bottom of that pile,
			 * completing it, and advance the fence to the next pile.
			 */
			*fence = key;
			fence += bins->size[bin];
			bins->size[bin] = 0;
		}

		/* Now bins->size[] is all zero again. */

		/*
		 * When the number of keys in a task gets small enough, it is faster to use an
		 * insertion sort than to keep subdividing into tiny piles.
		 */
		while (--insertion_task_list >= sorter->insertion_list)
			insertion_sort(*insertion_task_list);
	}

	return UDS_SUCCESS;
}