1/* SPDX-License-Identifier: GPL-2.0-only */
2/* Copyright (C) 2013 Jozsef Kadlecsik <kadlec@netfilter.org> */
3
4#ifndef __IP_SET_BITMAP_IP_GEN_H
5#define __IP_SET_BITMAP_IP_GEN_H
6
7#include <linux/rcupdate_wait.h>
8
/*
 * Map the generic mtype_* names used throughout this header onto per-type
 * identifiers. IPSET_TOKEN(MTYPE, _suffix) presumably pastes the suffix onto
 * MTYPE, which the including bitmap implementation is expected to define.
 */

/* Types */
#define mtype			MTYPE
#define mtype_elem		IPSET_TOKEN(MTYPE, _elem)
#define mtype_adt_elem		IPSET_TOKEN(MTYPE, _adt_elem)

/* Type-specific hooks called from this header. They are not defined here,
 * except for the mtype_is_filled stub used when IP_SET_BITMAP_STORED_TIMEOUT
 * is not defined.
 */
#define mtype_do_test		IPSET_TOKEN(MTYPE, _do_test)
#define mtype_do_add		IPSET_TOKEN(MTYPE, _do_add)
#define mtype_do_del		IPSET_TOKEN(MTYPE, _do_del)
#define mtype_do_list		IPSET_TOKEN(MTYPE, _do_list)
#define mtype_do_head		IPSET_TOKEN(MTYPE, _do_head)
#define mtype_gc_test		IPSET_TOKEN(MTYPE, _gc_test)
#define mtype_is_filled		IPSET_TOKEN(MTYPE, _is_filled)
#define mtype_add_timeout	IPSET_TOKEN(MTYPE, _add_timeout)
#define mtype_kadt		IPSET_TOKEN(MTYPE, _kadt)
#define mtype_uadt		IPSET_TOKEN(MTYPE, _uadt)
#define mtype_same_set		IPSET_TOKEN(MTYPE, _same_set)

/* Generic implementations provided by this header */
#define mtype_gc_init		IPSET_TOKEN(MTYPE, _gc_init)
#define mtype_ext_cleanup	IPSET_TOKEN(MTYPE, _ext_cleanup)
#define mtype_destroy		IPSET_TOKEN(MTYPE, _destroy)
#define mtype_flush		IPSET_TOKEN(MTYPE, _flush)
#define mtype_memsize		IPSET_TOKEN(MTYPE, _memsize)
#define mtype_head		IPSET_TOKEN(MTYPE, _head)
#define mtype_test		IPSET_TOKEN(MTYPE, _test)
#define mtype_add		IPSET_TOKEN(MTYPE, _add)
#define mtype_del		IPSET_TOKEN(MTYPE, _del)
#define mtype_list		IPSET_TOKEN(MTYPE, _list)
#define mtype_gc		IPSET_TOKEN(MTYPE, _gc)
#define mtype_cancel_gc		IPSET_TOKEN(MTYPE, _cancel_gc)

/* Start of the extension slot of element @id. Slots are set->dsize bytes
 * each and are laid out back to back starting at map->extensions.
 */
#define get_ext(set, map, id)	((map)->extensions + (id) * (set)->dsize)
37
38static void
39mtype_gc_init(struct ip_set *set, void (*gc)(struct timer_list *t))
40{
41	struct mtype *map = set->data;
42
43	timer_setup(&map->gc, gc, 0);
44	mod_timer(&map->gc, jiffies + IPSET_GC_PERIOD(set->timeout) * HZ);
45}
46
47static void
48mtype_ext_cleanup(struct ip_set *set)
49{
50	struct mtype *map = set->data;
51	u32 id;
52
53	for (id = 0; id < map->elements; id++)
54		if (test_bit(id, map->members))
55			ip_set_ext_destroy(set, get_ext(set, map, id));
56}
57
58static void
59mtype_destroy(struct ip_set *set)
60{
61	struct mtype *map = set->data;
62
63	if (set->dsize && set->extensions & IPSET_EXT_DESTROY)
64		mtype_ext_cleanup(set);
65	ip_set_free(map->members);
66	ip_set_free(map);
67
68	set->data = NULL;
69}
70
71static void
72mtype_flush(struct ip_set *set)
73{
74	struct mtype *map = set->data;
75
76	if (set->extensions & IPSET_EXT_DESTROY)
77		mtype_ext_cleanup(set);
78	bitmap_zero(map->members, map->elements);
79	set->elements = 0;
80	set->ext_size = 0;
81}
82
83/* Calculate the actual memory size of the set data */
84static size_t
85mtype_memsize(const struct mtype *map, size_t dsize)
86{
87	return sizeof(*map) + map->memsize +
88	       map->elements * dsize;
89}
90
91static int
92mtype_head(struct ip_set *set, struct sk_buff *skb)
93{
94	const struct mtype *map = set->data;
95	struct nlattr *nested;
96	size_t memsize = mtype_memsize(map, set->dsize) + set->ext_size;
97
98	nested = nla_nest_start(skb, IPSET_ATTR_DATA);
99	if (!nested)
100		goto nla_put_failure;
101	if (mtype_do_head(skb, map) ||
102	    nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) ||
103	    nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) ||
104	    nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(set->elements)))
105		goto nla_put_failure;
106	if (unlikely(ip_set_put_flags(skb, set)))
107		goto nla_put_failure;
108	nla_nest_end(skb, nested);
109
110	return 0;
111nla_put_failure:
112	return -EMSGSIZE;
113}
114
115static int
116mtype_test(struct ip_set *set, void *value, const struct ip_set_ext *ext,
117	   struct ip_set_ext *mext, u32 flags)
118{
119	struct mtype *map = set->data;
120	const struct mtype_adt_elem *e = value;
121	void *x = get_ext(set, map, e->id);
122	int ret = mtype_do_test(e, map, set->dsize);
123
124	if (ret <= 0)
125		return ret;
126	return ip_set_match_extensions(set, ext, mext, flags, x);
127}
128
129static int
130mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
131	  struct ip_set_ext *mext, u32 flags)
132{
133	struct mtype *map = set->data;
134	const struct mtype_adt_elem *e = value;
135	void *x = get_ext(set, map, e->id);
136	int ret = mtype_do_add(e, map, flags, set->dsize);
137
138	if (ret == IPSET_ADD_FAILED) {
139		if (SET_WITH_TIMEOUT(set) &&
140		    ip_set_timeout_expired(ext_timeout(x, set))) {
141			set->elements--;
142			ret = 0;
143		} else if (!(flags & IPSET_FLAG_EXIST)) {
144			set_bit(e->id, map->members);
145			return -IPSET_ERR_EXIST;
146		}
147		/* Element is re-added, cleanup extensions */
148		ip_set_ext_destroy(set, x);
149	}
150	if (ret > 0)
151		set->elements--;
152
153	if (SET_WITH_TIMEOUT(set))
154#ifdef IP_SET_BITMAP_STORED_TIMEOUT
155		mtype_add_timeout(ext_timeout(x, set), e, ext, set, map, ret);
156#else
157		ip_set_timeout_set(ext_timeout(x, set), ext->timeout);
158#endif
159
160	if (SET_WITH_COUNTER(set))
161		ip_set_init_counter(ext_counter(x, set), ext);
162	if (SET_WITH_COMMENT(set))
163		ip_set_init_comment(set, ext_comment(x, set), ext);
164	if (SET_WITH_SKBINFO(set))
165		ip_set_init_skbinfo(ext_skbinfo(x, set), ext);
166
167	/* Activate element */
168	set_bit(e->id, map->members);
169	set->elements++;
170
171	return 0;
172}
173
174static int
175mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
176	  struct ip_set_ext *mext, u32 flags)
177{
178	struct mtype *map = set->data;
179	const struct mtype_adt_elem *e = value;
180	void *x = get_ext(set, map, e->id);
181
182	if (mtype_do_del(e, map))
183		return -IPSET_ERR_EXIST;
184
185	ip_set_ext_destroy(set, x);
186	set->elements--;
187	if (SET_WITH_TIMEOUT(set) &&
188	    ip_set_timeout_expired(ext_timeout(x, set)))
189		return -IPSET_ERR_EXIST;
190
191	return 0;
192}
193
194#ifndef IP_SET_BITMAP_STORED_TIMEOUT
195static bool
196mtype_is_filled(const struct mtype_elem *x)
197{
198	return true;
199}
200#endif
201
202static int
203mtype_list(const struct ip_set *set,
204	   struct sk_buff *skb, struct netlink_callback *cb)
205{
206	struct mtype *map = set->data;
207	struct nlattr *adt, *nested;
208	void *x;
209	u32 id, first = cb->args[IPSET_CB_ARG0];
210	int ret = 0;
211
212	adt = nla_nest_start(skb, IPSET_ATTR_ADT);
213	if (!adt)
214		return -EMSGSIZE;
215	/* Extensions may be replaced */
216	rcu_read_lock();
217	for (; cb->args[IPSET_CB_ARG0] < map->elements;
218	     cb->args[IPSET_CB_ARG0]++) {
219		cond_resched_rcu();
220		id = cb->args[IPSET_CB_ARG0];
221		x = get_ext(set, map, id);
222		if (!test_bit(id, map->members) ||
223		    (SET_WITH_TIMEOUT(set) &&
224#ifdef IP_SET_BITMAP_STORED_TIMEOUT
225		     mtype_is_filled(x) &&
226#endif
227		     ip_set_timeout_expired(ext_timeout(x, set))))
228			continue;
229		nested = nla_nest_start(skb, IPSET_ATTR_DATA);
230		if (!nested) {
231			if (id == first) {
232				nla_nest_cancel(skb, adt);
233				ret = -EMSGSIZE;
234				goto out;
235			}
236
237			goto nla_put_failure;
238		}
239		if (mtype_do_list(skb, map, id, set->dsize))
240			goto nla_put_failure;
241		if (ip_set_put_extensions(skb, set, x, mtype_is_filled(x)))
242			goto nla_put_failure;
243		nla_nest_end(skb, nested);
244	}
245	nla_nest_end(skb, adt);
246
247	/* Set listing finished */
248	cb->args[IPSET_CB_ARG0] = 0;
249
250	goto out;
251
252nla_put_failure:
253	nla_nest_cancel(skb, nested);
254	if (unlikely(id == first)) {
255		cb->args[IPSET_CB_ARG0] = 0;
256		ret = -EMSGSIZE;
257	}
258	nla_nest_end(skb, adt);
259out:
260	rcu_read_unlock();
261	return ret;
262}
263
264static void
265mtype_gc(struct timer_list *t)
266{
267	struct mtype *map = from_timer(map, t, gc);
268	struct ip_set *set = map->set;
269	void *x;
270	u32 id;
271
272	/* We run parallel with other readers (test element)
273	 * but adding/deleting new entries is locked out
274	 */
275	spin_lock_bh(&set->lock);
276	for (id = 0; id < map->elements; id++)
277		if (mtype_gc_test(id, map, set->dsize)) {
278			x = get_ext(set, map, id);
279			if (ip_set_timeout_expired(ext_timeout(x, set))) {
280				clear_bit(id, map->members);
281				ip_set_ext_destroy(set, x);
282				set->elements--;
283			}
284		}
285	spin_unlock_bh(&set->lock);
286
287	map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
288	add_timer(&map->gc);
289}
290
291static void
292mtype_cancel_gc(struct ip_set *set)
293{
294	struct mtype *map = set->data;
295
296	if (SET_WITH_TIMEOUT(set))
297		del_timer_sync(&map->gc);
298}
299
300static const struct ip_set_type_variant mtype = {
301	.kadt	= mtype_kadt,
302	.uadt	= mtype_uadt,
303	.adt	= {
304		[IPSET_ADD] = mtype_add,
305		[IPSET_DEL] = mtype_del,
306		[IPSET_TEST] = mtype_test,
307	},
308	.destroy = mtype_destroy,
309	.flush	= mtype_flush,
310	.head	= mtype_head,
311	.list	= mtype_list,
312	.same_set = mtype_same_set,
313	.cancel_gc = mtype_cancel_gc,
314};
315
316#endif /* __IP_SET_BITMAP_IP_GEN_H */
317