1/*
2 * Copyright 2007, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
4 *
5 * Authors:
6 *      Hugo Santos, hugosantos@gmail.com
7 */
8#ifndef _PRIVATE_MULTICAST_H_
9#define _PRIVATE_MULTICAST_H_
10
11
#include <util/DoublyLinkedList.h>
#include <util/OpenHashTable.h>

#include <net_datalink.h>

#include <netinet/in.h>
#include <string.h>

#include <new>
#include <utility>
20
21
22struct net_buffer;
23struct net_protocol;
24
// This code is templatized so that it can be reused for IPv6.
26
27template<typename Addressing> class MulticastFilter;
28template<typename Addressing> class MulticastGroupInterface;
29
30// TODO move this elsewhere...
31struct IPv4Multicast {
32	typedef struct in_addr AddressType;
33	typedef struct ipv4_protocol ProtocolType;
34	typedef MulticastGroupInterface<IPv4Multicast> GroupInterface;
35
36	static status_t JoinGroup(GroupInterface *);
37	static status_t LeaveGroup(GroupInterface *);
38
39	static const in_addr &AddressFromSockAddr(const sockaddr *sockaddr)
40		{ return ((const sockaddr_in *)sockaddr)->sin_addr; }
41	static size_t HashAddress(const in_addr &address)
42		{ return address.s_addr; }
43};
44
45template<typename AddressType>
46class AddressSet {
47	struct ContainedAddress : DoublyLinkedListLinkImpl<ContainedAddress> {
48		AddressType address;
49	};
50
51	typedef DoublyLinkedList<ContainedAddress> AddressList;
52
53public:
54	AddressSet()
55		: fCount(0) {}
56
57	~AddressSet() { Clear(); }
58
59	status_t Add(const AddressType &address)
60	{
61		if (Has(address))
62			return B_OK;
63
64		ContainedAddress *container = new ContainedAddress();
65		if (container == NULL)
66			return B_NO_MEMORY;
67
68		container->address = address;
69		fAddresses.Add(container);
70
71		return B_OK;
72	}
73
74	void Remove(const AddressType &address)
75	{
76		ContainedAddress *container = _Get(address);
77		if (container == NULL)
78			return;
79
80		fAddresses.Remove(container);
81		delete container;
82	}
83
84	bool Has(const AddressType &address) const
85	{
86		return _Get(address) != NULL;
87	}
88
89	bool IsEmpty() const { return fAddresses.IsEmpty(); }
90
91	void Clear()
92	{
93		while (!fAddresses.IsEmpty())
94			Remove(fAddresses.Head()->address);
95	}
96
97	class Iterator {
98	public:
99		Iterator(const AddressList &addresses)
100			: fBaseIterator(addresses.GetIterator()) {}
101
102		bool HasNext() const { return fBaseIterator.HasNext(); }
103		AddressType &Next() { return fBaseIterator.Next()->address; }
104
105	private:
106		typename AddressList::ConstIterator fBaseIterator;
107	};
108
109	Iterator GetIterator() const { return Iterator(fAddresses); }
110
111private:
112	ContainedAddress *_Get(const AddressType &address) const
113	{
114		typename AddressList::ConstIterator it = fAddresses.GetIterator();
115		while (it.HasNext()) {
116			ContainedAddress *container = it.Next();
117			if (container->address == address)
118				return container;
119		}
120		return NULL;
121	}
122
123	AddressList fAddresses;
124	int fCount;
125};
126
127
128template<typename Addressing>
129class MulticastGroupInterface {
130public:
131	typedef MulticastGroupInterface<Addressing> ThisType;
132	typedef typename Addressing::AddressType AddressType;
133	typedef MulticastFilter<Addressing> Filter;
134	typedef ::AddressSet<AddressType> AddressSet;
135
136	enum FilterMode {
137		kInclude,
138		kExclude
139	};
140
141	MulticastGroupInterface(Filter *parent, const AddressType &address,
142		net_interface *interface);
143	~MulticastGroupInterface();
144
145	Filter *Parent() const { return fParent; }
146
147	const AddressType &Address() const { return fMulticastAddress; }
148	net_interface *Interface() const { return fInterface; }
149
150	status_t Add();
151	status_t Drop();
152	status_t BlockSource(const AddressType &sourceAddress);
153	status_t UnblockSource(const AddressType &sourceAddress);
154	status_t AddSSM(const AddressType &sourceAddress);
155	status_t DropSSM(const AddressType &sourceAddress);
156
157	bool IsEmpty() const;
158	void Clear();
159
160	FilterMode Mode() const { return fFilterMode; }
161	const AddressSet &Sources() const { return fAddresses; }
162
163	bool FilterAccepts(net_buffer *buffer) const;
164
165	struct HashDefinition {
166		typedef std::pair<const AddressType *, uint32> KeyType;
167		typedef ThisType ValueType;
168
169		size_t HashKey(const KeyType &key) const
170			{ return Addressing::HashAddress(*key.first) ^ key.second; }
171		size_t Hash(ValueType *value) const
172			{ return HashKey(std::make_pair(&value->Address(),
173				value->Interface()->index)); }
174		bool Compare(const KeyType &key, ValueType *value) const
175			{ return value->Interface()->index == key.second
176				&& value->Address().s_addr == key.first->s_addr; }
177		MulticastGroupInterface*& GetLink(ValueType *value) const
178			{ return value->HashLink(); }
179	};
180
181	MulticastGroupInterface*& HashLink() { return fLink; }
182	MulticastGroupInterface*& MulticastGroupsHashLink() { return fMulticastGroupsLink; }
183
184private:
185	// for g++ 2.95
186	friend struct HashDefinition;
187
188	Filter *fParent;
189	AddressType fMulticastAddress;
190	net_interface *fInterface;
191	FilterMode fFilterMode;
192	AddressSet fAddresses;
193	MulticastGroupInterface* fLink;
194	MulticastGroupInterface* fMulticastGroupsLink;
195};
196
197template<typename Addressing>
198class MulticastFilter {
199public:
200	typedef typename Addressing::AddressType AddressType;
201	typedef typename Addressing::ProtocolType ProtocolType;
202	typedef MulticastGroupInterface<Addressing> GroupInterface;
203
204	MulticastFilter(ProtocolType *parent);
205	~MulticastFilter();
206
207	ProtocolType *Socket() const { return fParent; }
208
209	status_t GetState(const AddressType &groupAddress,
210		net_interface *interface, GroupInterface* &state, bool create);
211	void ReturnState(GroupInterface *state);
212
213private:
214	typedef typename GroupInterface::HashDefinition HashDefinition;
215	typedef BOpenHashTable<HashDefinition> States;
216
217	void _ReturnState(GroupInterface *state);
218
219	ProtocolType *fParent;
220	States fStates;
221};
222
223#endif
224