1/*
2 * Ethernet Switch IGMP Snooper
3 * Copyright (C) 2014 ASUSTeK Inc.
4 * All Rights Reserved.
5
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License as
8 * published by the Free Software Foundation, either version 3 of the
9 * License, or (at your option) any later version.
10
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU Affero General Public License for more details.
15
16 * You should have received a copy of the GNU Affero General Public License
17 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 */
19
20#include <stdio.h>
21#include <stdint.h>
22#include <stdlib.h>
23#include <stddef.h>
24#include <string.h>
25#include <netinet/ether.h>
26
27#include "snooper.h"
28#include "queue.h"
29
/*
 * Debug tracing for the cache layer.  Compiled out entirely unless
 * DEBUG_CACHE is defined; log_debug() is provided elsewhere (presumably a
 * printf-style logger -- confirm in snooper.h).
 */
#ifdef DEBUG_CACHE
#define log_cache(fmt, args...) log_debug("%s::" fmt, "cache", ##args)
#else
#define log_cache(...) do {} while (0)
#endif

/* Upper bounds on the number of entries each pool may hand out. */
#define GROUP_POOL_SIZE 512
#define MEMBER_POOL_SIZE 1024
#define HOST_POOL_SIZE 32
/* Lifetime of a cached host->port answer, multiplied by TIMER_HZ where used
 * (so presumably seconds). */
#define HOST_TTL 3

/* Buckets per address hash table; HASH_INDEX maps a MAC address to one. */
#define HASH_SIZE 64
#define HASH_INDEX(ea) (ether_hash(ea) % HASH_SIZE)

/*
 * Both pools are allocated on demand.  The host flag was previously
 * undefined under the misspelled name "HOSTPOOL_STATIC", which left a
 * build-time HOST_POOL_STATIC silently in effect.
 */
#undef HOST_POOL_STATIC
#undef GROUP_POOL_STATIC
46
47struct host_entry {
48	STAILQ_ENTRY(host_entry) link;
49	LIST_ENTRY(host_entry) hash;
50	unsigned long time;
51	int port;
52	unsigned char ea[ETHER_ADDR_LEN];
53};
54static struct {
55	STAILQ_HEAD(, host_entry) pool;
56	LIST_HEAD(, host_entry) hash[HASH_SIZE];
57	int count;
58#ifdef HOST_POOL_STATIC
59	struct host_entry entries[HOST_POOL_SIZE];
60#endif
61} hosts;
62
63struct member_entry {
64	LIST_ENTRY(member_entry) link;
65	unsigned long time;
66	in_addr_t addr;
67};
68static struct {
69	LIST_HEAD(, member_entry) free;
70	int count;
71} members;
72
73struct group_entry {
74	STAILQ_ENTRY(group_entry) link;
75	LIST_ENTRY(group_entry) hash;
76	LIST_HEAD(, member_entry) members[PORT_MAX + 1];
77	unsigned long time;
78	int portmap;
79	unsigned char ea[ETHER_ADDR_LEN];
80};
81static struct {
82	STAILQ_HEAD(, group_entry) pool;
83	LIST_HEAD(, group_entry) hash[HASH_SIZE];
84	struct timer_entry timer;
85	int count;
86#ifdef GROUP_POOL_STATIC
87	struct group_entry entries[GROUP_POOL_SIZE];
88#endif
89} groups;
90static struct {
91	struct group_entry group;
92	struct timer_entry timer;
93} routers;
94
95static void group_timer(struct timer_entry *timer, void *data);
96static void router_timer(struct timer_entry *timer, void *data);
97
98static struct host_entry *get_host(unsigned char *ea, unsigned long time)
99{
100	struct host_entry *host, *prev;
101	int index = HASH_INDEX(ea);
102
103	LIST_FOREACH(host, &hosts.hash[index], hash) {
104		if (memcmp(host->ea, ea, ETHER_ADDR_LEN) == 0)
105			return host;
106	}
107
108	if (hosts.count < HOST_POOL_SIZE) {
109#ifdef HOST_POOL_STATIC
110		host = &hosts.entries[hosts.count++];
111#else
112		host = calloc(1, sizeof(*host));
113		if (!host)
114			hosts.count++;
115#endif
116	}
117	if (!host) {
118		prev = NULL;
119		STAILQ_FOREACH(host, &hosts.pool, link) {
120			if (time_before(host->time, time))
121				break;
122			prev = host;
123		}
124		if (!host)
125			return NULL;
126		LIST_REMOVE(host, hash);
127		if (prev)
128			STAILQ_REMOVE_AFTER(&hosts.pool, prev, link);
129		else	STAILQ_REMOVE_HEAD(&hosts.pool, link);
130		memset(&host, 0, sizeof(*host));
131	}
132
133	memcpy(host->ea, ea, ETHER_ADDR_LEN);
134	LIST_INSERT_HEAD(&hosts.hash[index], host, hash);
135	STAILQ_INSERT_TAIL(&hosts.pool, host, link);
136
137	return host;
138}
139
140int get_port(unsigned char *haddr)
141{
142	struct host_entry *host;
143	unsigned long time = now();
144	int port;
145
146	host = get_host(haddr, time);
147	if (host && time_after_eq(host->time, time)) {
148		log_cache("%-6s [" FMT_EA "] = " FMT_PORTS, "port",
149		    ARG_EA(haddr), ARG_PORTS((host->port < 0) ? -1 : 1 << host->port));
150		return host->port;
151	}
152
153	port = switch_get_port(haddr);
154	log_cache("%-6s [" FMT_EA "] = " FMT_PORTS, "read",
155	    ARG_EA(haddr), ARG_PORTS((port < 0) ? -1 : 1 << port));
156
157	if (host && 0 <= port && port <= PORT_MAX) {
158		host->port = port;
159		host->time = time + HOST_TTL * TIMER_HZ;
160	}
161
162	return port;
163}
164
165static struct member_entry *get_member(struct group_entry *group, in_addr_t addr, int port, int allocate)
166{
167	struct member_entry *member;
168
169	LIST_FOREACH(member, &group->members[port], link) {
170		if (member->addr == addr)
171			return member;
172	}
173	if (!allocate)
174		return NULL;
175
176	member = LIST_FIRST(&members.free);
177	if (member) {
178		LIST_REMOVE(member, link);
179		memset(member, 0, sizeof(*member));
180	} else
181	if (members.count < MEMBER_POOL_SIZE) {
182		member = calloc(1, sizeof(*member));
183		if (member)
184			members.count++;
185	}
186	if (!member)
187		return NULL;
188
189//	member->time = now();
190	member->addr = addr;
191	LIST_INSERT_HEAD(&group->members[port], member, link);
192
193	return member;
194}
195
196static void consume_member(struct member_entry *member)
197{
198	LIST_REMOVE(member, link);
199	LIST_INSERT_HEAD(&members.free, member, link);
200}
201
202static void init_group(struct group_entry *group)
203{
204	int port;
205
206//	group->time = now();
207	group->portmap = 0;
208	for (port = 0; port <= PORT_MAX; port++)
209		LIST_INIT(&group->members[port]);
210}
211
212static struct group_entry *get_group(unsigned char *ea, int allocate)
213{
214	struct group_entry *group, *prev;
215	int index = HASH_INDEX(ea);
216
217	LIST_FOREACH(group, &groups.hash[index], hash) {
218		if (memcmp(group->ea, ea, ETHER_ADDR_LEN) == 0)
219			return group;
220	}
221	if (!allocate)
222		return NULL;
223
224	if (groups.count < GROUP_POOL_SIZE) {
225#ifdef GROUP_POOL_STATIC
226		group = &groups.entries[groups.count++];
227#else
228		group = calloc(1, sizeof(*group));
229		if (group)
230			groups.count++;
231#endif
232	}
233	if (!group) {
234		prev = NULL;
235		STAILQ_FOREACH(group, &groups.pool, link) {
236			if (group->portmap == 0)
237				break;
238			prev = group;
239		}
240		if (!group)
241			return NULL;
242		LIST_REMOVE(group, hash);
243		if (prev)
244			STAILQ_REMOVE_AFTER(&groups.pool, prev, link);
245		else 	STAILQ_REMOVE_HEAD(&groups.pool, link);
246		switch_clr_portmap(group->ea);
247		memset(&group, 0, sizeof(group));
248	}
249
250	init_group(group);
251	memcpy(group->ea, ea, ETHER_ADDR_LEN);
252	LIST_INSERT_HEAD(&groups.hash[index], group, hash);
253	STAILQ_INSERT_TAIL(&groups.pool, group, link);
254
255	return group;
256}
257
258static void consume_group(struct group_entry *group)
259{
260	struct member_entry *member, *next;
261	int port;
262
263	group->portmap = 0;
264	for (port = 0; port <= PORT_MAX; port++) {
265		LIST_FOREACH_SAFE(member, &group->members[port], link, next)
266			consume_member(member);
267	}
268}
269
270static int get_portmap(struct group_entry *group)
271{
272	int port, portmap = 0;
273
274	for (port = 0; port <= PORT_MAX; port++) {
275		if (!LIST_EMPTY(&group->members[port]))
276			portmap |= 1 << port;
277	}
278
279	return portmap;
280}
281
282int init_cache(void)
283{
284	int index;
285
286	memset(&hosts, 0, sizeof(hosts));
287	memset(&members, 0, sizeof(members));
288	memset(&groups, 0, sizeof(groups));
289	memset(&routers, 0, sizeof(routers));
290
291	STAILQ_INIT(&hosts.pool);
292	LIST_INIT(&members.free);
293	STAILQ_INIT(&groups.pool);
294	for (index = 0; index < HASH_SIZE; index++) {
295		LIST_INIT(&groups.hash[index]);
296		LIST_INIT(&hosts.hash[index]);
297	}
298
299	init_group(&routers.group);
300
301	set_timer(&groups.timer, group_timer, NULL);
302	set_timer(&routers.timer, router_timer, &routers.group);
303
304	log_cache("%-6s pool(%u x hash) = %u, entries(%u x %u) = %u", "groups",
305	    HASH_SIZE, sizeof(groups), GROUP_POOL_SIZE, sizeof(struct group_entry),
306#ifdef HOST_POOL_STATIC
307	    0 *
308#endif
309	    GROUP_POOL_SIZE * sizeof(struct group_entry));
310	log_cache("%-6s pool = %u, entries(%u x %u) = %u", "member",
311	    sizeof(members), MEMBER_POOL_SIZE, sizeof(struct member_entry),
312	    MEMBER_POOL_SIZE * sizeof(struct member_entry));
313	log_cache("%-6s pool(%u x hash) = %u, entries(%u x %u) = %u", "hosts",
314	    HASH_SIZE, sizeof(hosts), HOST_POOL_SIZE, sizeof(struct host_entry),
315#ifdef GROUP_POOL_STATIC
316	    0 *
317#endif
318	    HOST_POOL_SIZE * sizeof(struct host_entry));
319
320	return 0;
321}
322
323static void group_timer(struct timer_entry *timer, void *data)
324{
325	struct group_entry *group;
326	unsigned long expires, time = now();
327	int portmap;
328
329	expires = time + ~0UL/2;
330	STAILQ_FOREACH(group, &groups.pool, link) {
331		portmap = group->portmap;
332		if (portmap == 0)
333			continue;
334		if (time_after(group->time, time)) {
335			if (time_before(group->time, expires))
336				expires = group->time;
337			continue;
338		} else
339			consume_group(group);
340
341		log_cache("%-6s [" FMT_EA "] - " FMT_PORTS, "expire",
342		    ARG_EA(group->ea), ARG_PORTS(portmap));
343
344		portmap &= ~routers.group.portmap;
345		if (portmap)
346			switch_del_portmap(group->ea, portmap);
347	}
348
349	if (time_before(expires, time + ~0UL/2))
350		mod_timer(timer, expires);
351}
352
353int add_member(unsigned char *maddr, in_addr_t addr, int port, int timeout)
354{
355	struct group_entry *group;
356	struct member_entry *member;
357	struct timer_entry *timer;
358	int portmap;
359
360	if (port < 0 || port > PORT_MAX)
361		return -1;
362
363	group = get_group(maddr, 1);
364	if (group) {
365		portmap = group->portmap;
366
367		group->time = now() + timeout;
368		member = get_member(group, addr, port, 1);
369		if (member)
370			member->time = group->time;
371		group->portmap = get_portmap(group);
372		portmap = (portmap ^ group->portmap) & group->portmap;
373
374		timer = &groups.timer;
375		if (!timer_pending(timer) || time_before(group->time, timer->expires))
376			mod_timer(timer, group->time);
377
378		log_cache("%-6s [" FMT_EA "] + " FMT_PORTS " add " FMT_IP " expires in %d", "member",
379		    ARG_EA(group->ea), ARG_PORTS(portmap), ARG_IP(&addr), timeout / TIMER_HZ);
380	} else
381		portmap = 0;
382
383	if (portmap)
384		switch_add_portmap(maddr, portmap | routers.group.portmap);
385
386	return portmap;
387}
388
389int del_member(unsigned char *maddr, in_addr_t addr, int port)
390{
391	struct group_entry *group;
392	struct member_entry *member;
393	int portmap;
394
395	if (port < 0 || port > PORT_MAX)
396		return -1;
397
398	group = get_group(maddr, 0);
399	if (group) {
400		portmap = group->portmap;
401
402		member = get_member(group, addr, port, 0);
403		if (member)
404			consume_member(member);
405		group->portmap = get_portmap(group);
406		portmap = (portmap ^ group->portmap) & portmap;
407		if (portmap && group->portmap == 0)
408			consume_group(group);
409
410		log_cache("%-6s [" FMT_EA "] - " FMT_PORTS " del " FMT_IP, "member",
411		    ARG_EA(group->ea), ARG_PORTS(portmap), ARG_IP(&addr));
412	} else
413		portmap = 0;
414
415	portmap &= ~routers.group.portmap;
416	if (portmap)
417		switch_del_portmap(maddr, portmap);
418
419	return portmap;
420}
421
422static void router_timer(struct timer_entry *timer, void *data)
423{
424	struct group_entry *group = data;
425	struct member_entry *member, *next;
426	unsigned long time = now();
427	int port, portmap, groupmap;
428
429	portmap = group->portmap;
430	if (portmap < 0)
431		return;
432
433	if (time_after(group->time, time)) {
434		group->time = time + ~0UL/2;
435		for (port = 0; port <= PORT_MAX; port++) {
436			LIST_FOREACH_SAFE(member, &group->members[port], link, next) {
437				if (time_after(member->time, time)) {
438					if (time_before(member->time, group->time))
439						group->time = member->time;
440					continue;
441				} else
442					consume_member(member);
443			}
444		}
445		group->portmap = get_portmap(group);
446		portmap = (portmap ^ group->portmap) & portmap;
447		if (group->portmap)
448			mod_timer(timer, group->time);
449		else
450			consume_group(group);
451	} else
452		consume_group(group);
453
454	log_cache("%-6s [" FMT_EA "] - " FMT_PORTS, "expire",
455	    ARG_EA(group->ea), ARG_PORTS(portmap));
456
457	if (portmap) {
458		STAILQ_FOREACH(group, &groups.pool, link) {
459			groupmap = portmap & ~group->portmap;
460			if (groupmap)
461				switch_del_portmap(group->ea, groupmap);
462		}
463	}
464}
465
466int add_router(in_addr_t addr, int port, int timeout)
467{
468	struct group_entry *group;
469	struct member_entry *member;
470	struct timer_entry *timer;
471	int portmap, groupmap;
472
473	if (port < 0 || port > PORT_MAX)
474		return -1;
475
476	group = &routers.group;
477	if (group) {
478		portmap = group->portmap;
479
480		group->time = now() + timeout;
481		member = get_member(group, addr, port, 1);
482		if (member)
483			member->time = group->time;
484		group->portmap = get_portmap(group);
485		portmap = (portmap ^ group->portmap) & group->portmap;
486
487		timer = &routers.timer;
488		if (!timer_pending(timer) || time_after(timer->expires, group->time))
489			mod_timer(timer, group->time);
490
491		log_cache("%-6s [" FMT_EA "] + " FMT_PORTS " add " FMT_IP " expires in %d", "router",
492		    ARG_EA(group->ea), ARG_PORTS(portmap), ARG_IP(&addr), timeout / TIMER_HZ);
493	} else
494		portmap = 0;
495
496	if (portmap) {
497		STAILQ_FOREACH(group, &groups.pool, link) {
498			groupmap = portmap & ~group->portmap;
499			if (groupmap)
500				switch_add_portmap(group->ea, groupmap);
501		}
502	}
503
504	return portmap;
505}
506
507int expire_members(unsigned char *maddr, int timeout)
508{
509	struct group_entry *group;
510	unsigned long time = now() + timeout;
511
512	if (maddr) {
513		group = get_group(maddr, 0);
514		if (!group)
515			return -1;
516		group->time = time;
517	} else
518	STAILQ_FOREACH(group, &groups.pool, link) {
519		group->time = time;
520	}
521
522	log_cache("%-6s fast expire %s in %d", "expire", maddr ? "group" : "all", timeout / TIMER_HZ);
523	if (!timer_pending(&groups.timer) || time_after(groups.timer.expires, time))
524		mod_timer(&groups.timer, time);
525
526	return 0;
527}
528
529int purge_cache(void)
530{
531	struct group_entry *group;
532	struct member_entry *member, *next_member;
533	struct host_entry *host;
534
535	del_timer(&groups.timer);
536	del_timer(&routers.timer);
537
538	while ((group = STAILQ_FIRST(&groups.pool))) {
539		consume_group(group);
540		LIST_REMOVE(group, hash);
541		STAILQ_REMOVE_HEAD(&groups.pool, link);
542		switch_clr_portmap(group->ea);
543#ifndef GROUP_POOL_STATIC
544		free(group);
545#endif
546	}
547	consume_group(&routers.group);
548
549	LIST_FOREACH_SAFE(member, &members.free, link, next_member) {
550		LIST_REMOVE(member, link);
551		free(member);
552	}
553
554	while ((host = STAILQ_FIRST(&hosts.pool))) {
555		STAILQ_REMOVE_HEAD(&hosts.pool, link);
556		LIST_REMOVE(host, hash);
557#ifndef HOST_POOL_STATIC
558		free(host);
559#endif
560	}
561
562	return 0;
563}
564