1/*	$NetBSD$	*/
2
3/*
4 * Copyright (C) 2001-2004 Sistina Software, Inc. All rights reserved.
5 * Copyright (C) 2004-2007 Red Hat, Inc. All rights reserved.
6 *
7 * This file is part of the device-mapper userspace tools.
8 *
9 * This copyrighted material is made available to anyone wishing to use,
10 * modify, copy, or redistribute it subject to the terms and conditions
11 * of the GNU Lesser General Public License v.2.1.
12 *
13 * You should have received a copy of the GNU Lesser General Public License
14 * along with this program; if not, write to the Free Software Foundation,
15 * Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16 */
17
18#include "dmlib.h"
19#include "parse_rx.h"
20#include "ttree.h"
21#include "assert.h"
22
/*
 * A single state of the compiled DFA.
 */
struct dfa_state {
	/*
	 * Zero for a non-accepting state, otherwise the 1-based number
	 * taken from the regex node that marks the end of a pattern.
	 */
	int final;

	/* Next state for every possible input byte; NULL means no move. */
	struct dfa_state *lookup[256];
};
27
28struct state_queue {
29	struct dfa_state *s;
30	dm_bitset_t bits;
31	struct state_queue *next;
32};
33
/*
 * Instance variables for the lexer (one compiled matcher).
 */
struct dm_regex {
	struct dfa_state *start;	/* state the matcher begins in */
	unsigned num_nodes;		/* node count of the parse tree */
	int nodes_entered;		/* slots of nodes[] filled so far */
	struct rx_node **nodes;		/* flat table of parse tree nodes */
	struct dm_pool *scratch;	/* build-time pool, NULL once built */
	struct dm_pool *mem;		/* pool the DFA states are taken from */
};
41
/*
 * Marker byte appended to every pattern when the patterns are joined
 * into one expression.  A node whose charset contains it marks the
 * end of a pattern.
 */
#define TARGET_TRANS '\0'
43
44static int _count_nodes(struct rx_node *rx)
45{
46	int r = 1;
47
48	if (rx->left)
49		r += _count_nodes(rx->left);
50
51	if (rx->right)
52		r += _count_nodes(rx->right);
53
54	return r;
55}
56
57static void _fill_table(struct dm_regex *m, struct rx_node *rx)
58{
59	assert((rx->type != OR) || (rx->left && rx->right));
60
61	if (rx->left)
62		_fill_table(m, rx->left);
63
64	if (rx->right)
65		_fill_table(m, rx->right);
66
67	m->nodes[m->nodes_entered++] = rx;
68}
69
70static void _create_bitsets(struct dm_regex *m)
71{
72	int i;
73
74	for (i = 0; i < m->num_nodes; i++) {
75		struct rx_node *n = m->nodes[i];
76		n->firstpos = dm_bitset_create(m->scratch, m->num_nodes);
77		n->lastpos = dm_bitset_create(m->scratch, m->num_nodes);
78		n->followpos = dm_bitset_create(m->scratch, m->num_nodes);
79	}
80}
81
82static void _calc_functions(struct dm_regex *m)
83{
84	int i, j, final = 1;
85	struct rx_node *rx, *c1, *c2;
86
87	for (i = 0; i < m->num_nodes; i++) {
88		rx = m->nodes[i];
89		c1 = rx->left;
90		c2 = rx->right;
91
92		if (dm_bit(rx->charset, TARGET_TRANS))
93			rx->final = final++;
94
95		switch (rx->type) {
96		case CAT:
97			if (c1->nullable)
98				dm_bit_union(rx->firstpos,
99					  c1->firstpos, c2->firstpos);
100			else
101				dm_bit_copy(rx->firstpos, c1->firstpos);
102
103			if (c2->nullable)
104				dm_bit_union(rx->lastpos,
105					  c1->lastpos, c2->lastpos);
106			else
107				dm_bit_copy(rx->lastpos, c2->lastpos);
108
109			rx->nullable = c1->nullable && c2->nullable;
110			break;
111
112		case PLUS:
113			dm_bit_copy(rx->firstpos, c1->firstpos);
114			dm_bit_copy(rx->lastpos, c1->lastpos);
115			rx->nullable = c1->nullable;
116			break;
117
118		case OR:
119			dm_bit_union(rx->firstpos, c1->firstpos, c2->firstpos);
120			dm_bit_union(rx->lastpos, c1->lastpos, c2->lastpos);
121			rx->nullable = c1->nullable || c2->nullable;
122			break;
123
124		case QUEST:
125		case STAR:
126			dm_bit_copy(rx->firstpos, c1->firstpos);
127			dm_bit_copy(rx->lastpos, c1->lastpos);
128			rx->nullable = 1;
129			break;
130
131		case CHARSET:
132			dm_bit_set(rx->firstpos, i);
133			dm_bit_set(rx->lastpos, i);
134			rx->nullable = 0;
135			break;
136
137		default:
138			log_error("Internal error: Unknown calc node type");
139		}
140
141		/*
142		 * followpos has it's own switch
143		 * because PLUS and STAR do the
144		 * same thing.
145		 */
146		switch (rx->type) {
147		case CAT:
148			for (j = 0; j < m->num_nodes; j++) {
149				if (dm_bit(c1->lastpos, j)) {
150					struct rx_node *n = m->nodes[j];
151					dm_bit_union(n->followpos,
152						  n->followpos, c2->firstpos);
153				}
154			}
155			break;
156
157		case PLUS:
158		case STAR:
159			for (j = 0; j < m->num_nodes; j++) {
160				if (dm_bit(rx->lastpos, j)) {
161					struct rx_node *n = m->nodes[j];
162					dm_bit_union(n->followpos,
163						  n->followpos, rx->firstpos);
164				}
165			}
166			break;
167		}
168	}
169}
170
171static struct dfa_state *_create_dfa_state(struct dm_pool *mem)
172{
173	return dm_pool_zalloc(mem, sizeof(struct dfa_state));
174}
175
176static struct state_queue *_create_state_queue(struct dm_pool *mem,
177					       struct dfa_state *dfa,
178					       dm_bitset_t bits)
179{
180	struct state_queue *r = dm_pool_alloc(mem, sizeof(*r));
181
182	if (!r) {
183		stack;
184		return NULL;
185	}
186
187	r->s = dfa;
188	r->bits = dm_bitset_create(mem, bits[0]);	/* first element is the size */
189	dm_bit_copy(r->bits, bits);
190	r->next = 0;
191	return r;
192}
193
194static int _calc_states(struct dm_regex *m, struct rx_node *rx)
195{
196	unsigned iwidth = (m->num_nodes / DM_BITS_PER_INT) + 1;
197	struct ttree *tt = ttree_create(m->scratch, iwidth);
198	struct state_queue *h, *t, *tmp;
199	struct dfa_state *dfa, *ldfa;
200	int i, a, set_bits = 0, count = 0;
201	dm_bitset_t bs, dfa_bits;
202
203	if (!tt)
204		return_0;
205
206	if (!(bs = dm_bitset_create(m->scratch, m->num_nodes)))
207		return_0;
208
209	/* create first state */
210	dfa = _create_dfa_state(m->mem);
211	m->start = dfa;
212	ttree_insert(tt, rx->firstpos + 1, dfa);
213
214	/* prime the queue */
215	h = t = _create_state_queue(m->scratch, dfa, rx->firstpos);
216	while (h) {
217		/* pop state off front of the queue */
218		dfa = h->s;
219		dfa_bits = h->bits;
220		h = h->next;
221
222		/* iterate through all the inputs for this state */
223		dm_bit_clear_all(bs);
224		for (a = 0; a < 256; a++) {
225			/* iterate through all the states in firstpos */
226			for (i = dm_bit_get_first(dfa_bits);
227			     i >= 0; i = dm_bit_get_next(dfa_bits, i)) {
228				if (dm_bit(m->nodes[i]->charset, a)) {
229					if (a == TARGET_TRANS)
230						dfa->final = m->nodes[i]->final;
231
232					dm_bit_union(bs, bs,
233						  m->nodes[i]->followpos);
234					set_bits = 1;
235				}
236			}
237
238			if (set_bits) {
239				ldfa = ttree_lookup(tt, bs + 1);
240				if (!ldfa) {
241					/* push */
242					ldfa = _create_dfa_state(m->mem);
243					ttree_insert(tt, bs + 1, ldfa);
244					tmp =
245					    _create_state_queue(m->scratch,
246								ldfa, bs);
247					if (!h)
248						h = t = tmp;
249					else {
250						t->next = tmp;
251						t = tmp;
252					}
253
254					count++;
255				}
256
257				dfa->lookup[a] = ldfa;
258				set_bits = 0;
259				dm_bit_clear_all(bs);
260			}
261		}
262	}
263
264	log_debug("Matcher built with %d dfa states", count);
265	return 1;
266}
267
268struct dm_regex *dm_regex_create(struct dm_pool *mem, const char **patterns,
269				 unsigned num_patterns)
270{
271	char *all, *ptr;
272	int i;
273	size_t len = 0;
274	struct rx_node *rx;
275	struct dm_pool *scratch = dm_pool_create("regex matcher", 10 * 1024);
276	struct dm_regex *m;
277
278	if (!scratch)
279		return_NULL;
280
281	if (!(m = dm_pool_alloc(mem, sizeof(*m)))) {
282		dm_pool_destroy(scratch);
283		return_NULL;
284	}
285
286	memset(m, 0, sizeof(*m));
287
288	/* join the regexps together, delimiting with zero */
289	for (i = 0; i < num_patterns; i++)
290		len += strlen(patterns[i]) + 8;
291
292	ptr = all = dm_pool_alloc(scratch, len + 1);
293
294	if (!all)
295		goto_bad;
296
297	for (i = 0; i < num_patterns; i++) {
298		ptr += sprintf(ptr, "(.*(%s)%c)", patterns[i], TARGET_TRANS);
299		if (i < (num_patterns - 1))
300			*ptr++ = '|';
301	}
302
303	/* parse this expression */
304	if (!(rx = rx_parse_tok(scratch, all, ptr))) {
305		log_error("Couldn't parse regex");
306		goto bad;
307	}
308
309	m->mem = mem;
310	m->scratch = scratch;
311	m->num_nodes = _count_nodes(rx);
312	m->nodes = dm_pool_alloc(scratch, sizeof(*m->nodes) * m->num_nodes);
313
314	if (!m->nodes)
315		goto_bad;
316
317	_fill_table(m, rx);
318	_create_bitsets(m);
319	_calc_functions(m);
320	_calc_states(m, rx);
321	dm_pool_destroy(scratch);
322	m->scratch = NULL;
323
324	return m;
325
326      bad:
327	dm_pool_destroy(scratch);
328	dm_pool_free(mem, m);
329	return NULL;
330}
331
332static struct dfa_state *_step_matcher(int c, struct dfa_state *cs, int *r)
333{
334	if (!(cs = cs->lookup[(unsigned char) c]))
335		return NULL;
336
337	if (cs->final && (cs->final > *r))
338		*r = cs->final;
339
340	return cs;
341}
342
343int dm_regex_match(struct dm_regex *regex, const char *s)
344{
345	struct dfa_state *cs = regex->start;
346	int r = 0;
347
348	if (!(cs = _step_matcher(HAT_CHAR, cs, &r)))
349		goto out;
350
351	for (; *s; s++)
352		if (!(cs = _step_matcher(*s, cs, &r)))
353			goto out;
354
355	_step_matcher(DOLLAR_CHAR, cs, &r);
356
357      out:
358	/* subtract 1 to get back to zero index */
359	return r - 1;
360}
361