1// SPDX-License-Identifier: GPL-2.0-only
2/* Copyright (C) 2008-2013 Jozsef Kadlecsik <kadlec@netfilter.org> */
3
4/* Kernel module implementing an IP set type: the list:set type */
5
6#include <linux/module.h>
7#include <linux/ip.h>
8#include <linux/rculist.h>
9#include <linux/skbuff.h>
10#include <linux/errno.h>
11
12#include <linux/netfilter/ipset/ip_set.h>
13#include <linux/netfilter/ipset/ip_set_list.h>
14
/*
 * Supported revisions of the list:set type. Revision 0 is the base type;
 * each later revision added support for one more extension, as noted.
 */
#define IPSET_TYPE_REV_MIN	0
/*				1    Counters support added */
/*				2    Comments support added */
#define IPSET_TYPE_REV_MAX	3 /* skbinfo support added */

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@netfilter.org>");
IP_SET_MODULE_DESC("list:set", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX);
MODULE_ALIAS("ip_set_list:set");
24
/*
 * Member elements: one per set stored in the list. Each element is
 * allocated with set->dsize bytes; the extension area (timeout, counters,
 * comment, skbinfo) is reached through the ext_*() helpers.
 */
struct set_elem {
	struct rcu_head rcu;	/* deferred free, see __list_set_del_rcu() */
	struct list_head list;	/* linkage in list_set.members */
	struct ip_set *set;	/* owning list:set: the RCU free callback
				 * only gets the element, but needs the set
				 * to destroy the extensions
				 */
	ip_set_id_t id;		/* index of the member set */
} __aligned(__alignof__(u64));
32
/* Decoded userspace add/del/test request, built by list_set_uadt() */
struct set_adt_elem {
	ip_set_id_t id;		/* the set to add, delete or test */
	ip_set_id_t refid;	/* reference set for before/after, or
				 * IPSET_INVALID_ID when none was given
				 */
	int before;		/* 0: no reference, > 0: before refid,
				 * < 0: after refid
				 */
};
38
/* Type structure: private data of a list:set, stored in set->data */
struct list_set {
	u32 size;		/* size reported to userspace; the add path
				 * in this file does not check it
				 */
	struct timer_list gc;	/* garbage collection (timeout sets only) */
	struct ip_set *set;	/* attached to this ip_set */
	struct net *net;	/* namespace */
	struct list_head members; /* the set members, RCU protected */
};
47
48static int
49list_set_ktest(struct ip_set *set, const struct sk_buff *skb,
50	       const struct xt_action_param *par,
51	       struct ip_set_adt_opt *opt, const struct ip_set_ext *ext)
52{
53	struct list_set *map = set->data;
54	struct ip_set_ext *mext = &opt->ext;
55	struct set_elem *e;
56	u32 flags = opt->cmdflags;
57	int ret;
58
59	/* Don't lookup sub-counters at all */
60	opt->cmdflags &= ~IPSET_FLAG_MATCH_COUNTERS;
61	if (opt->cmdflags & IPSET_FLAG_SKIP_SUBCOUNTER_UPDATE)
62		opt->cmdflags |= IPSET_FLAG_SKIP_COUNTER_UPDATE;
63	list_for_each_entry_rcu(e, &map->members, list) {
64		ret = ip_set_test(e->id, skb, par, opt);
65		if (ret <= 0)
66			continue;
67		if (ip_set_match_extensions(set, ext, mext, flags, e))
68			return 1;
69	}
70	return 0;
71}
72
73static int
74list_set_kadd(struct ip_set *set, const struct sk_buff *skb,
75	      const struct xt_action_param *par,
76	      struct ip_set_adt_opt *opt, const struct ip_set_ext *ext)
77{
78	struct list_set *map = set->data;
79	struct set_elem *e;
80	int ret;
81
82	list_for_each_entry(e, &map->members, list) {
83		if (SET_WITH_TIMEOUT(set) &&
84		    ip_set_timeout_expired(ext_timeout(e, set)))
85			continue;
86		ret = ip_set_add(e->id, skb, par, opt);
87		if (ret == 0)
88			return ret;
89	}
90	return 0;
91}
92
93static int
94list_set_kdel(struct ip_set *set, const struct sk_buff *skb,
95	      const struct xt_action_param *par,
96	      struct ip_set_adt_opt *opt, const struct ip_set_ext *ext)
97{
98	struct list_set *map = set->data;
99	struct set_elem *e;
100	int ret;
101
102	list_for_each_entry(e, &map->members, list) {
103		if (SET_WITH_TIMEOUT(set) &&
104		    ip_set_timeout_expired(ext_timeout(e, set)))
105			continue;
106		ret = ip_set_del(e->id, skb, par, opt);
107		if (ret == 0)
108			return ret;
109	}
110	return 0;
111}
112
113static int
114list_set_kadt(struct ip_set *set, const struct sk_buff *skb,
115	      const struct xt_action_param *par,
116	      enum ipset_adt adt, struct ip_set_adt_opt *opt)
117{
118	struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set);
119	int ret = -EINVAL;
120
121	rcu_read_lock();
122	switch (adt) {
123	case IPSET_TEST:
124		ret = list_set_ktest(set, skb, par, opt, &ext);
125		break;
126	case IPSET_ADD:
127		ret = list_set_kadd(set, skb, par, opt, &ext);
128		break;
129	case IPSET_DEL:
130		ret = list_set_kdel(set, skb, par, opt, &ext);
131		break;
132	default:
133		break;
134	}
135	rcu_read_unlock();
136
137	return ret;
138}
139
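/* Userspace interfaces: we are protected by the nfnl mutex.
 * The element helpers right below are also reached from the gc timer
 * (list_set_gc -> set_cleanup_entries), which holds set->lock instead.
 */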
141
142static void
143__list_set_del_rcu(struct rcu_head * rcu)
144{
145	struct set_elem *e = container_of(rcu, struct set_elem, rcu);
146	struct ip_set *set = e->set;
147
148	ip_set_ext_destroy(set, e);
149	kfree(e);
150}
151
152static void
153list_set_del(struct ip_set *set, struct set_elem *e)
154{
155	struct list_set *map = set->data;
156
157	set->elements--;
158	list_del_rcu(&e->list);
159	ip_set_put_byindex(map->net, e->id);
160	call_rcu(&e->rcu, __list_set_del_rcu);
161}
162
163static void
164list_set_replace(struct ip_set *set, struct set_elem *e, struct set_elem *old)
165{
166	struct list_set *map = set->data;
167
168	list_replace_rcu(&old->list, &e->list);
169	ip_set_put_byindex(map->net, old->id);
170	call_rcu(&old->rcu, __list_set_del_rcu);
171}
172
173static void
174set_cleanup_entries(struct ip_set *set)
175{
176	struct list_set *map = set->data;
177	struct set_elem *e, *n;
178
179	list_for_each_entry_safe(e, n, &map->members, list)
180		if (ip_set_timeout_expired(ext_timeout(e, set)))
181			list_set_del(set, e);
182}
183
184static int
185list_set_utest(struct ip_set *set, void *value, const struct ip_set_ext *ext,
186	       struct ip_set_ext *mext, u32 flags)
187{
188	struct list_set *map = set->data;
189	struct set_adt_elem *d = value;
190	struct set_elem *e, *next, *prev = NULL;
191	int ret;
192
193	list_for_each_entry(e, &map->members, list) {
194		if (SET_WITH_TIMEOUT(set) &&
195		    ip_set_timeout_expired(ext_timeout(e, set)))
196			continue;
197		else if (e->id != d->id) {
198			prev = e;
199			continue;
200		}
201
202		if (d->before == 0) {
203			ret = 1;
204		} else if (d->before > 0) {
205			next = list_next_entry(e, list);
206			ret = !list_is_last(&e->list, &map->members) &&
207			      next->id == d->refid;
208		} else {
209			ret = prev && prev->id == d->refid;
210		}
211		return ret;
212	}
213	return 0;
214}
215
216static void
217list_set_init_extensions(struct ip_set *set, const struct ip_set_ext *ext,
218			 struct set_elem *e)
219{
220	if (SET_WITH_COUNTER(set))
221		ip_set_init_counter(ext_counter(e, set), ext);
222	if (SET_WITH_COMMENT(set))
223		ip_set_init_comment(set, ext_comment(e, set), ext);
224	if (SET_WITH_SKBINFO(set))
225		ip_set_init_skbinfo(ext_skbinfo(e, set), ext);
226	/* Update timeout last */
227	if (SET_WITH_TIMEOUT(set))
228		ip_set_timeout_set(ext_timeout(e, set), ext->timeout);
229}
230
231static int
232list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext,
233	      struct ip_set_ext *mext, u32 flags)
234{
235	struct list_set *map = set->data;
236	struct set_adt_elem *d = value;
237	struct set_elem *e, *n, *prev, *next;
238	bool flag_exist = flags & IPSET_FLAG_EXIST;
239
240	/* Find where to add the new entry */
241	n = prev = next = NULL;
242	list_for_each_entry(e, &map->members, list) {
243		if (SET_WITH_TIMEOUT(set) &&
244		    ip_set_timeout_expired(ext_timeout(e, set)))
245			continue;
246		else if (d->id == e->id)
247			n = e;
248		else if (d->before == 0 || e->id != d->refid)
249			continue;
250		else if (d->before > 0)
251			next = e;
252		else
253			prev = e;
254	}
255
256	/* If before/after is used on an empty set */
257	if ((d->before > 0 && !next) ||
258	    (d->before < 0 && !prev))
259		return -IPSET_ERR_REF_EXIST;
260
261	/* Re-add already existing element */
262	if (n) {
263		if (!flag_exist)
264			return -IPSET_ERR_EXIST;
265		/* Update extensions */
266		ip_set_ext_destroy(set, n);
267		list_set_init_extensions(set, ext, n);
268
269		/* Set is already added to the list */
270		ip_set_put_byindex(map->net, d->id);
271		return 0;
272	}
273	/* Add new entry */
274	if (d->before == 0) {
275		/* Append  */
276		n = list_empty(&map->members) ? NULL :
277		    list_last_entry(&map->members, struct set_elem, list);
278	} else if (d->before > 0) {
279		/* Insert after next element */
280		if (!list_is_last(&next->list, &map->members))
281			n = list_next_entry(next, list);
282	} else {
283		/* Insert before prev element */
284		if (prev->list.prev != &map->members)
285			n = list_prev_entry(prev, list);
286	}
287	/* Can we replace a timed out entry? */
288	if (n &&
289	    !(SET_WITH_TIMEOUT(set) &&
290	      ip_set_timeout_expired(ext_timeout(n, set))))
291		n = NULL;
292
293	e = kzalloc(set->dsize, GFP_ATOMIC);
294	if (!e)
295		return -ENOMEM;
296	e->id = d->id;
297	e->set = set;
298	INIT_LIST_HEAD(&e->list);
299	list_set_init_extensions(set, ext, e);
300	if (n)
301		list_set_replace(set, e, n);
302	else if (next)
303		list_add_tail_rcu(&e->list, &next->list);
304	else if (prev)
305		list_add_rcu(&e->list, &prev->list);
306	else
307		list_add_tail_rcu(&e->list, &map->members);
308	set->elements++;
309
310	return 0;
311}
312
313static int
314list_set_udel(struct ip_set *set, void *value, const struct ip_set_ext *ext,
315	      struct ip_set_ext *mext, u32 flags)
316{
317	struct list_set *map = set->data;
318	struct set_adt_elem *d = value;
319	struct set_elem *e, *next, *prev = NULL;
320
321	list_for_each_entry(e, &map->members, list) {
322		if (SET_WITH_TIMEOUT(set) &&
323		    ip_set_timeout_expired(ext_timeout(e, set)))
324			continue;
325		else if (e->id != d->id) {
326			prev = e;
327			continue;
328		}
329
330		if (d->before > 0) {
331			next = list_next_entry(e, list);
332			if (list_is_last(&e->list, &map->members) ||
333			    next->id != d->refid)
334				return -IPSET_ERR_REF_EXIST;
335		} else if (d->before < 0) {
336			if (!prev || prev->id != d->refid)
337				return -IPSET_ERR_REF_EXIST;
338		}
339		list_set_del(set, e);
340		return 0;
341	}
342	return d->before != 0 ? -IPSET_ERR_REF_EXIST : -IPSET_ERR_EXIST;
343}
344
/*
 * Userspace add/del/test entry point. Decodes the netlink attributes into
 * a struct set_adt_elem and calls the variant's handler for @adt
 * (set->variant->adt[adt]).
 *
 * Set references: the ip_set_get_byname() lookups below are balanced at
 * "finish". The reference on the reference set (e.refid) is always put.
 * The reference on the member set (e.id) is put unless this is a
 * successful add; in that case it is owned by the new list element.
 */
static int
list_set_uadt(struct ip_set *set, struct nlattr *tb[],
	      enum ipset_adt adt, u32 *lineno, u32 flags, bool retried)
{
	struct list_set *map = set->data;
	ipset_adtfn adtfn = set->variant->adt[adt];
	struct set_adt_elem e = { .refid = IPSET_INVALID_ID };
	struct ip_set_ext ext = IP_SET_INIT_UEXT(set);
	struct ip_set *s;
	int ret = 0;

	if (tb[IPSET_ATTR_LINENO])
		*lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]);

	if (unlikely(!tb[IPSET_ATTR_NAME] ||
		     !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
		return -IPSET_ERR_PROTOCOL;

	ret = ip_set_get_extensions(set, tb, &ext);
	if (ret)
		return ret;
	e.id = ip_set_get_byname(map->net, nla_data(tb[IPSET_ATTR_NAME]), &s);
	if (e.id == IPSET_INVALID_ID)
		return -IPSET_ERR_NAME;
	/* "Loop detection": refuse member sets whose type has the
	 * IPSET_TYPE_NAME feature (list:set itself does, see list_set_type),
	 * so a list cannot contain such a set.
	 */
	if (s->type->features & IPSET_TYPE_NAME) {
		ret = -IPSET_ERR_LOOP;
		goto finish;
	}

	if (tb[IPSET_ATTR_CADT_FLAGS]) {
		u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);

		e.before = f & IPSET_FLAG_BEFORE;
	}

	/* "before" needs a reference set to be before */
	if (e.before && !tb[IPSET_ATTR_NAMEREF]) {
		ret = -IPSET_ERR_BEFORE;
		goto finish;
	}

	if (tb[IPSET_ATTR_NAMEREF]) {
		e.refid = ip_set_get_byname(map->net,
					    nla_data(tb[IPSET_ATTR_NAMEREF]),
					    &s);
		if (e.refid == IPSET_INVALID_ID) {
			ret = -IPSET_ERR_NAMEREF;
			goto finish;
		}
		/* A reference without the "before" flag means "after" */
		if (!e.before)
			e.before = -1;
	}
	/* Drop timed out members before modifying the list */
	if (adt != IPSET_TEST && SET_WITH_TIMEOUT(set))
		set_cleanup_entries(set);

	ret = adtfn(set, &e, &ext, &ext, flags);

finish:
	if (e.refid != IPSET_INVALID_ID)
		ip_set_put_byindex(map->net, e.refid);
	/* A successful add keeps the member set reference in the element */
	if (adt != IPSET_ADD || ret)
		ip_set_put_byindex(map->net, e.id);

	return ip_set_eexist(ret, flags) ? 0 : ret;
}
410
411static void
412list_set_flush(struct ip_set *set)
413{
414	struct list_set *map = set->data;
415	struct set_elem *e, *n;
416
417	list_for_each_entry_safe(e, n, &map->members, list)
418		list_set_del(set, e);
419	set->elements = 0;
420	set->ext_size = 0;
421}
422
423static void
424list_set_destroy(struct ip_set *set)
425{
426	struct list_set *map = set->data;
427	struct set_elem *e, *n;
428
429	list_for_each_entry_safe(e, n, &map->members, list) {
430		list_del(&e->list);
431		ip_set_put_byindex(map->net, e->id);
432		ip_set_ext_destroy(set, e);
433		kfree(e);
434	}
435	kfree(map);
436
437	set->data = NULL;
438}
439
440/* Calculate the actual memory size of the set data */
441static size_t
442list_set_memsize(const struct list_set *map, size_t dsize)
443{
444	struct set_elem *e;
445	u32 n = 0;
446
447	rcu_read_lock();
448	list_for_each_entry_rcu(e, &map->members, list)
449		n++;
450	rcu_read_unlock();
451
452	return (sizeof(*map) + n * dsize);
453}
454
455static int
456list_set_head(struct ip_set *set, struct sk_buff *skb)
457{
458	const struct list_set *map = set->data;
459	struct nlattr *nested;
460	size_t memsize = list_set_memsize(map, set->dsize) + set->ext_size;
461
462	nested = nla_nest_start(skb, IPSET_ATTR_DATA);
463	if (!nested)
464		goto nla_put_failure;
465	if (nla_put_net32(skb, IPSET_ATTR_SIZE, htonl(map->size)) ||
466	    nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) ||
467	    nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) ||
468	    nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(set->elements)))
469		goto nla_put_failure;
470	if (unlikely(ip_set_put_flags(skb, set)))
471		goto nla_put_failure;
472	nla_nest_end(skb, nested);
473
474	return 0;
475nla_put_failure:
476	return -EMSGSIZE;
477}
478
/*
 * Dump the members of @set into @skb. Output is one IPSET_ATTR_DATA nest
 * per member (member set name plus extensions) inside an IPSET_ATTR_ADT
 * nest. Timed out members are skipped.
 *
 * cb->args[IPSET_CB_ARG0] is the resume cursor: the index of the first
 * member to emit in this round. It is reset to 0 when the walk completes.
 * When a member does not fit, the cursor is set to that member's index so
 * the next round resumes there. If the failing member is exactly the one
 * at the cursor (i == first), the ADT nest is cancelled, the cursor is
 * reset and -EMSGSIZE is returned.
 */
static int
list_set_list(const struct ip_set *set,
	      struct sk_buff *skb, struct netlink_callback *cb)
{
	const struct list_set *map = set->data;
	struct nlattr *atd, *nested;
	u32 i = 0, first = cb->args[IPSET_CB_ARG0];
	char name[IPSET_MAXNAMELEN];
	struct set_elem *e;
	int ret = 0;

	atd = nla_nest_start(skb, IPSET_ATTR_ADT);
	if (!atd)
		return -EMSGSIZE;

	rcu_read_lock();
	list_for_each_entry_rcu(e, &map->members, list) {
		/* Members before the cursor were dumped in an earlier round;
		 * timed out ones are not listed. i still counts both, so it
		 * remains a plain position index into the list.
		 */
		if (i < first ||
		    (SET_WITH_TIMEOUT(set) &&
		     ip_set_timeout_expired(ext_timeout(e, set)))) {
			i++;
			continue;
		}
		nested = nla_nest_start(skb, IPSET_ATTR_DATA);
		if (!nested)
			goto nla_put_failure;
		ip_set_name_byindex(map->net, e->id, name);
		if (nla_put_string(skb, IPSET_ATTR_NAME, name))
			goto nla_put_failure;
		if (ip_set_put_extensions(skb, set, e, true))
			goto nla_put_failure;
		nla_nest_end(skb, nested);
		i++;
	}

	nla_nest_end(skb, atd);
	/* Set listing finished */
	cb->args[IPSET_CB_ARG0] = 0;
	goto out;

nla_put_failure:
	/* nested is NULL here when nla_nest_start() itself failed.
	 * NOTE(review): relies on nla_nest_cancel() ignoring a NULL nest.
	 */
	nla_nest_cancel(skb, nested);
	if (unlikely(i == first)) {
		nla_nest_cancel(skb, atd);
		cb->args[IPSET_CB_ARG0] = 0;
		ret = -EMSGSIZE;
	} else {
		/* Resume from the member that did not fit */
		cb->args[IPSET_CB_ARG0] = i;
		nla_nest_end(skb, atd);
	}
out:
	rcu_read_unlock();
	return ret;
}
533
534static bool
535list_set_same_set(const struct ip_set *a, const struct ip_set *b)
536{
537	const struct list_set *x = a->data;
538	const struct list_set *y = b->data;
539
540	return x->size == y->size &&
541	       a->timeout == b->timeout &&
542	       a->extensions == b->extensions;
543}
544
545static void
546list_set_cancel_gc(struct ip_set *set)
547{
548	struct list_set *map = set->data;
549
550	if (SET_WITH_TIMEOUT(set))
551		timer_shutdown_sync(&map->gc);
552}
553
/* Operations of the list:set type, installed by list_set_create() */
static const struct ip_set_type_variant set_variant = {
	.kadt	= list_set_kadt,	/* kernel (packet path) add/del/test */
	.uadt	= list_set_uadt,	/* userspace request decoding */
	.adt	= {			/* userspace handlers called by uadt */
		[IPSET_ADD] = list_set_uadd,
		[IPSET_DEL] = list_set_udel,
		[IPSET_TEST] = list_set_utest,
	},
	.destroy = list_set_destroy,
	.flush	= list_set_flush,
	.head	= list_set_head,
	.list	= list_set_list,
	.same_set = list_set_same_set,
	.cancel_gc = list_set_cancel_gc,
};
569
570static void
571list_set_gc(struct timer_list *t)
572{
573	struct list_set *map = from_timer(map, t, gc);
574	struct ip_set *set = map->set;
575
576	spin_lock_bh(&set->lock);
577	set_cleanup_entries(set);
578	spin_unlock_bh(&set->lock);
579
580	map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ;
581	add_timer(&map->gc);
582}
583
584static void
585list_set_gc_init(struct ip_set *set, void (*gc)(struct timer_list *t))
586{
587	struct list_set *map = set->data;
588
589	timer_setup(&map->gc, gc, 0);
590	mod_timer(&map->gc, jiffies + IPSET_GC_PERIOD(set->timeout) * HZ);
591}
592
593/* Create list:set type of sets */
594
595static bool
596init_list_set(struct net *net, struct ip_set *set, u32 size)
597{
598	struct list_set *map;
599
600	map = kzalloc(sizeof(*map), GFP_KERNEL);
601	if (!map)
602		return false;
603
604	map->size = size;
605	map->net = net;
606	map->set = set;
607	INIT_LIST_HEAD(&map->members);
608	set->data = map;
609
610	return true;
611}
612
613static int
614list_set_create(struct net *net, struct ip_set *set, struct nlattr *tb[],
615		u32 flags)
616{
617	u32 size = IP_SET_LIST_DEFAULT_SIZE;
618
619	if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_SIZE) ||
620		     !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
621		     !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
622		return -IPSET_ERR_PROTOCOL;
623
624	if (tb[IPSET_ATTR_SIZE])
625		size = ip_set_get_h32(tb[IPSET_ATTR_SIZE]);
626	if (size < IP_SET_LIST_MIN_SIZE)
627		size = IP_SET_LIST_MIN_SIZE;
628
629	set->variant = &set_variant;
630	set->dsize = ip_set_elem_len(set, tb, sizeof(struct set_elem),
631				     __alignof__(struct set_elem));
632	if (!init_list_set(net, set, size))
633		return -ENOMEM;
634	if (tb[IPSET_ATTR_TIMEOUT]) {
635		set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
636		list_set_gc_init(set, list_set_gc);
637	}
638	return 0;
639}
640
/*
 * Type descriptor registered in list_set_init(). create_policy lists the
 * netlink attributes accepted when creating a set. adt_policy lists those
 * accepted by add/del/test requests (decoded in list_set_uadt()).
 */
static struct ip_set_type list_set_type __read_mostly = {
	.name		= "list:set",
	.protocol	= IPSET_PROTOCOL,
	.features	= IPSET_TYPE_NAME | IPSET_DUMP_LAST,
	.dimension	= IPSET_DIM_ONE,
	.family		= NFPROTO_UNSPEC,
	.revision_min	= IPSET_TYPE_REV_MIN,
	.revision_max	= IPSET_TYPE_REV_MAX,
	.create		= list_set_create,
	.create_policy	= {
		[IPSET_ATTR_SIZE]	= { .type = NLA_U32 },
		[IPSET_ATTR_TIMEOUT]	= { .type = NLA_U32 },
		[IPSET_ATTR_CADT_FLAGS]	= { .type = NLA_U32 },
	},
	.adt_policy	= {
		/* member set name and the before/after reference set name */
		[IPSET_ATTR_NAME]	= { .type = NLA_STRING,
					    .len = IPSET_MAXNAMELEN },
		[IPSET_ATTR_NAMEREF]	= { .type = NLA_STRING,
					    .len = IPSET_MAXNAMELEN },
		[IPSET_ATTR_TIMEOUT]	= { .type = NLA_U32 },
		[IPSET_ATTR_LINENO]	= { .type = NLA_U32 },
		[IPSET_ATTR_CADT_FLAGS]	= { .type = NLA_U32 },
		[IPSET_ATTR_BYTES]	= { .type = NLA_U64 },
		[IPSET_ATTR_PACKETS]	= { .type = NLA_U64 },
		[IPSET_ATTR_COMMENT]	= { .type = NLA_NUL_STRING,
					    .len  = IPSET_MAX_COMMENT_SIZE },
		[IPSET_ATTR_SKBMARK]	= { .type = NLA_U64 },
		[IPSET_ATTR_SKBPRIO]	= { .type = NLA_U32 },
		[IPSET_ATTR_SKBQUEUE]	= { .type = NLA_U16 },
	},
	.me		= THIS_MODULE,
};
673
674static int __init
675list_set_init(void)
676{
677	return ip_set_type_register(&list_set_type);
678}
679
/*
 * Module exit. rcu_barrier() first waits for every pending call_rcu()
 * callback (__list_set_del_rcu) to finish, because that code lives in
 * this module. Only then is the type unregistered.
 */
static void __exit
list_set_fini(void)
{
	rcu_barrier();
	ip_set_type_unregister(&list_set_type);
}

module_init(list_set_init);
module_exit(list_set_fini);
689