1/*
2 * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de. All rights reserved.
3 * Distributed under the terms of the MIT License.
4 */
5
6#include <grp.h>
7
8#include <errno.h>
9#include <string.h>
10#include <unistd.h>
11
12#include <new>
13
14#include <OS.h>
15
16#include <errno_private.h>
17#include <libroot_private.h>
18#include <RegistrarDefs.h>
19#include <user_group.h>
20
21#include <util/KMessage.h>
22
23
24using BPrivate::UserGroupLocker;
25using BPrivate::relocate_pointer;
26
27
// Reply of the last B_REG_GET_GROUP_DB request. init_group_db() makes
// sGroupEntries point at the "entries" data found in this message.
static KMessage sGroupDBReply;
// Entry table used by getgrent_r(); NULL while no database is loaded.
static group** sGroupEntries = NULL;
// Number of entries in sGroupEntries.
static size_t sGroupEntryCount = 0;
// Index of the entry the next getgrent_r() call hands out.
static size_t sIterationIndex = 0;

// Static storage returned by getgrent(), getgrnam() and getgrgid().
static struct group sGroupBuffer;
static char sGroupStringBuffer[MAX_GROUP_BUFFER_SIZE];
35
36
37static status_t
38query_group_entry(const char* name, gid_t _gid, struct group *group,
39	char *buffer, size_t bufferSize, struct group **_result)
40{
41	*_result = NULL;
42
43	KMessage message(BPrivate::B_REG_GET_GROUP);
44	if (name)
45		message.AddString("name", name);
46	else
47		message.AddInt32("gid", _gid);
48
49	KMessage reply;
50	status_t error = BPrivate::send_authentication_request_to_registrar(message,
51		reply);
52	if (error != B_OK) {
53		return error == ENOENT ? B_OK : error;
54	}
55
56	int32 gid;
57	const char* password;
58
59	if ((error = reply.FindInt32("gid", &gid)) != B_OK
60		|| (error = reply.FindString("name", &name)) != B_OK
61		|| (error = reply.FindString("password", &password)) != B_OK) {
62		return error;
63	}
64
65	const char* members[MAX_GROUP_MEMBER_COUNT];
66	int memberCount = 0;
67	for (int32 index = 0; memberCount < MAX_GROUP_MEMBER_COUNT; index++) {
68		if (reply.FindString("members", index, members + memberCount) != B_OK)
69			break;
70		memberCount++;
71	}
72
73	error = BPrivate::copy_group_to_buffer(name, password, gid, members,
74		memberCount, group, buffer, bufferSize);
75	if (error == B_OK)
76		*_result = group;
77
78	return error;
79}
80
81
82static status_t
83init_group_db()
84{
85	if (sGroupEntries != NULL)
86		return B_OK;
87
88	// ask the registrar
89	KMessage message(BPrivate::B_REG_GET_GROUP_DB);
90	status_t error = BPrivate::send_authentication_request_to_registrar(message,
91		sGroupDBReply);
92	if (error != B_OK)
93		return error;
94
95	// unpack the reply
96	int32 count;
97	group** entries;
98	int32 numBytes;
99	if ((error = sGroupDBReply.FindInt32("count", &count)) != B_OK
100		|| (error = sGroupDBReply.FindData("entries", B_RAW_TYPE,
101				(const void**)&entries, &numBytes)) != B_OK) {
102		return error;
103	}
104
105	// relocate the entries
106	addr_t baseAddress = (addr_t)entries;
107	for (int32 i = 0; i < count; i++) {
108		group* entry = relocate_pointer(baseAddress, entries[i]);
109		relocate_pointer(baseAddress, entry->gr_name);
110		relocate_pointer(baseAddress, entry->gr_passwd);
111		relocate_pointer(baseAddress, entry->gr_mem);
112		int32 k = 0;
113		for (; entry->gr_mem[k] != (void*)-1; k++)
114			relocate_pointer(baseAddress, entry->gr_mem[k]);
115		entry->gr_mem[k] = NULL;
116	}
117
118	sGroupEntries = entries;
119	sGroupEntryCount = count;
120
121	return B_OK;
122}
123
124
125// #pragma mark -
126
127
128struct group*
129getgrent(void)
130{
131	struct group* result = NULL;
132	int status = getgrent_r(&sGroupBuffer, sGroupStringBuffer,
133		sizeof(sGroupStringBuffer), &result);
134	if (status != 0)
135		__set_errno(status);
136	return result;
137}
138
139
140int
141getgrent_r(struct group* group, char* buffer, size_t bufferSize,
142	struct group** _result)
143{
144	UserGroupLocker _;
145
146	int status = B_NO_MEMORY;
147
148	*_result = NULL;
149
150	if ((status = init_group_db()) == B_OK) {
151		if (sIterationIndex >= sGroupEntryCount)
152			return ENOENT;
153
154		status = BPrivate::copy_group_to_buffer(
155			sGroupEntries[sIterationIndex], group, buffer, bufferSize);
156
157		if (status == B_OK) {
158			sIterationIndex++;
159			*_result = group;
160		}
161	}
162
163	return status;
164}
165
166
167void
168setgrent(void)
169{
170	UserGroupLocker _;
171
172	sIterationIndex = 0;
173}
174
175
176void
177endgrent(void)
178{
179	UserGroupLocker locker;
180
181	sGroupDBReply.Unset();
182	sGroupEntries = NULL;
183	sGroupEntryCount = 0;
184	sIterationIndex = 0;
185}
186
187
188struct group *
189getgrnam(const char *name)
190{
191	struct group* result = NULL;
192	int status = getgrnam_r(name, &sGroupBuffer, sGroupStringBuffer,
193		sizeof(sGroupStringBuffer), &result);
194	if (status != 0)
195		__set_errno(status);
196	return result;
197}
198
199
200int
201getgrnam_r(const char *name, struct group *group, char *buffer,
202	size_t bufferSize, struct group **_result)
203{
204	return query_group_entry(name, 0, group, buffer, bufferSize, _result);
205}
206
207
208struct group *
209getgrgid(gid_t gid)
210{
211	struct group* result = NULL;
212	int status = getgrgid_r(gid, &sGroupBuffer, sGroupStringBuffer,
213		sizeof(sGroupStringBuffer), &result);
214	if (status != 0)
215		__set_errno(status);
216	return result;
217}
218
219
220int
221getgrgid_r(gid_t gid, struct group *group, char *buffer,
222	size_t bufferSize, struct group **_result)
223{
224	return query_group_entry(NULL, gid, group, buffer, bufferSize, _result);
225}
226
227
228int
229getgrouplist(const char* user, gid_t baseGroup, gid_t* groupList,
230	int* groupCount)
231{
232	int maxGroupCount = *groupCount;
233	*groupCount = 0;
234
235	status_t error = B_OK;
236
237	// prepare request
238	KMessage message(BPrivate::B_REG_GET_USER_GROUPS);
239	if (message.AddString("name", user) != B_OK
240		|| message.AddInt32("max count", maxGroupCount) != B_OK) {
241		return -1;
242	}
243
244	// send request
245	KMessage reply;
246	error = BPrivate::send_authentication_request_to_registrar(message, reply);
247	if (error != B_OK)
248		return -1;
249
250	// unpack reply
251	int32 count;
252	const int32* groups;
253	int32 groupsSize;
254	if (reply.FindInt32("count", &count) != B_OK
255		|| reply.FindData("groups", B_INT32_TYPE, (const void**)&groups,
256				&groupsSize) != B_OK) {
257		return -1;
258	}
259
260	memcpy(groupList, groups, groupsSize);
261	*groupCount = count;
262
263	// add the base group
264	if (*groupCount < maxGroupCount)
265		groupList[*groupCount] = baseGroup;
266	++*groupCount;
267
268	return *groupCount <= maxGroupCount ? *groupCount : -1;
269}
270
271
272int
273initgroups(const char* user, gid_t baseGroup)
274{
275	gid_t groups[NGROUPS_MAX + 1];
276	int groupCount = NGROUPS_MAX + 1;
277	if (getgrouplist(user, baseGroup, groups, &groupCount) < 0)
278		return -1;
279
280	return setgroups(groupCount, groups);
281}
282