1// AuthenticationServer.cpp
2
3#include "AuthenticationServer.h"
4
5#include <new>
6
7#include <HashMap.h>
8#include <HashString.h>
9#include <util/KMessage.h>
10
11#include "AuthenticationPanel.h"
12#include "AuthenticationServerDefs.h"
13#include "DebugSupport.h"
14#include "TaskManager.h"
15
16
17// Authentication
18class AuthenticationServer::Authentication {
19public:
20	Authentication()
21		: fUser(),
22		  fPassword()
23	{
24	}
25
26	Authentication(const char* user, const char* password)
27		: fUser(user),
28		  fPassword(password)
29	{
30	}
31
32	status_t SetTo(const char* user, const char* password)
33	{
34		if (fUser.SetTo(user) && fPassword.SetTo(password))
35			return B_OK;
36		return B_NO_MEMORY;
37	}
38
39	bool IsValid() const
40	{
41		return (fUser.GetLength() > 0);
42	}
43
44	const char* GetUser() const
45	{
46		return fUser.GetString();
47	}
48
49	const char* GetPassword() const
50	{
51		return fPassword.GetString();
52	}
53
54private:
55	HashString	fUser;
56	HashString	fPassword;
57};
58
59// ServerKey
60class AuthenticationServer::ServerKey {
61public:
62	ServerKey()
63		: fContext(),
64		  fServer()
65	{
66	}
67
68	ServerKey(const char* context, const char* server)
69		: fContext(context),
70		  fServer(server)
71	{
72	}
73
74	ServerKey(const ServerKey& other)
75		: fContext(other.fContext),
76		  fServer(other.fServer)
77	{
78	}
79
80	uint32 GetHashCode() const
81	{
82		return fContext.GetHashCode() * 17 + fServer.GetHashCode();
83	}
84
85	ServerKey& operator=(const ServerKey& other)
86	{
87		fContext = other.fContext;
88		fServer = other.fServer;
89		return *this;
90	}
91
92	bool operator==(const ServerKey& other) const
93	{
94		return (fContext == other.fContext && fServer == other.fServer);
95	}
96
97	bool operator!=(const ServerKey& other) const
98	{
99		return !(*this == other);
100	}
101
102private:
103	HashString	fContext;
104	HashString	fServer;
105};
106
107// ServerEntry
108class AuthenticationServer::ServerEntry {
109public:
110	ServerEntry()
111		: fDefaultAuthentication(),
112		  fUseDefaultAuthentication(false)
113	{
114	}
115
116	~ServerEntry()
117	{
118		// delete the authentications
119		for (AuthenticationMap::Iterator it = fAuthentications.GetIterator();
120			 it.HasNext();) {
121			delete it.Next().value;
122		}
123	}
124
125	void SetUseDefaultAuthentication(bool useDefaultAuthentication)
126	{
127		fUseDefaultAuthentication = useDefaultAuthentication;
128	}
129
130	bool UseDefaultAuthentication() const
131	{
132		return fUseDefaultAuthentication;
133	}
134
135	status_t SetDefaultAuthentication(const char* user, const char* password)
136	{
137		return fDefaultAuthentication.SetTo(user, password);
138	}
139
140	const Authentication& GetDefaultAuthentication() const
141	{
142		return fDefaultAuthentication;
143	}
144
145	status_t SetAuthentication(const char* share, const char* user,
146		const char* password)
147	{
148		// check, if an entry already exists for the share -- if it does,
149		// just set it
150		Authentication* authentication = fAuthentications.Get(share);
151		if (authentication)
152			return authentication->SetTo(user, password);
153		// the entry does not exist yet: create and add a new one
154		authentication = new(std::nothrow) Authentication;
155		if (!authentication)
156			return B_NO_MEMORY;
157		status_t error = authentication->SetTo(user, password);
158		if (error == B_OK)
159			error = fAuthentications.Put(share, authentication);
160		if (error != B_OK)
161			delete authentication;
162		return error;
163	}
164
165	Authentication* GetAuthentication(const char* share) const
166	{
167		return fAuthentications.Get(share);
168	}
169
170private:
171	typedef HashMap<HashString, Authentication*> AuthenticationMap;
172
173	Authentication		fDefaultAuthentication;
174	bool				fUseDefaultAuthentication;
175	AuthenticationMap	fAuthentications;
176};
177
178// ServerEntryMap
179struct AuthenticationServer::ServerEntryMap
180	: HashMap<ServerKey, ServerEntry*> {
181};
182
183// UserDialogTask
184class AuthenticationServer::UserDialogTask : public Task {
185public:
186	UserDialogTask(AuthenticationServer* authenticationServer,
187		const char* context, const char* server, const char* share,
188		bool badPassword, port_id replyPort,
189		int32 replyToken)
190		: Task("user dialog task"),
191		  fAuthenticationServer(authenticationServer),
192		  fContext(context),
193		  fServer(server),
194		  fShare(share),
195		  fBadPassword(badPassword),
196		  fReplyPort(replyPort),
197		  fReplyToken(replyToken),
198		  fPanel(NULL)
199	{
200	}
201
202	virtual status_t Execute()
203	{
204		// open the panel
205		char user[B_OS_NAME_LENGTH];
206		char password[B_OS_NAME_LENGTH];
207		bool keep = true;
208		fPanel = new(std::nothrow) AuthenticationPanel();
209		status_t error = (fPanel ? B_OK : B_NO_MEMORY);
210		bool cancelled = false;
211		HashString defaultUser;
212		HashString defaultPassword;
213		fAuthenticationServer->_GetAuthentication(fContext.GetString(),
214			fServer.GetString(), NULL, &defaultUser, &defaultPassword);
215		if (error == B_OK) {
216			cancelled = fPanel->GetAuthentication(fServer.GetString(),
217				fShare.GetString(), defaultUser.GetString(),
218				defaultPassword.GetString(), false, fBadPassword, user,
219				password, &keep);
220		}
221		fPanel = NULL;
222		// send the reply
223		if (error != B_OK) {
224			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
225				error, true, NULL, NULL);
226		} else if (cancelled) {
227			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
228				B_OK, true, NULL, NULL);
229		} else {
230			fAuthenticationServer->_AddAuthentication(fContext.GetString(),
231				fServer.GetString(), fShare.GetString(), user, password,
232				keep);
233			fAuthenticationServer->_SendRequestReply(fReplyPort, fReplyToken,
234				B_OK, false, user, password);
235		}
236		return error;
237	}
238
239	virtual void Stop()
240	{
241		if (fPanel)
242			fPanel->Cancel();
243	}
244
245private:
246	AuthenticationServer*	fAuthenticationServer;
247	HashString				fContext;
248	HashString				fServer;
249	HashString				fShare;
250	bool					fBadPassword;
251	port_id					fReplyPort;
252	int32					fReplyToken;
253	AuthenticationPanel*	fPanel;
254};
255
256
257// constructor
258AuthenticationServer::AuthenticationServer()
259	:
260	BApplication("application/x-vnd.haiku-authentication_server"),
261	fLock(),
262	fRequestPort(-1),
263	fRequestThread(-1),
264	fServerEntries(NULL),
265	fTerminating(false)
266{
267}
268
269// destructor
270AuthenticationServer::~AuthenticationServer()
271{
272	fTerminating = true;
273	// terminate the request thread
274	if (fRequestPort >= 0)
275		delete_port(fRequestPort);
276	if (fRequestThread >= 0) {
277		int32 result;
278		wait_for_thread(fRequestPort, &result);
279	}
280	// delete the server entries
281	for (ServerEntryMap::Iterator it = fServerEntries->GetIterator();
282		 it.HasNext();) {
283		delete it.Next().value;
284	}
285}
286
287// Init
288status_t
289AuthenticationServer::Init()
290{
291	// create the server entry map
292	fServerEntries = new(std::nothrow) ServerEntryMap;
293	if (!fServerEntries)
294		return B_NO_MEMORY;
295	status_t error = fServerEntries->InitCheck();
296	if (error != B_OK)
297		return error;
298	// create the request port
299	fRequestPort = create_port(10, kAuthenticationServerPortName);
300	if (fRequestPort < 0)
301		return fRequestPort;
302	// spawn the request thread
303	fRequestThread = spawn_thread(&_RequestThreadEntry, "request thread",
304		B_NORMAL_PRIORITY, this);
305	if (fRequestThread < 0)
306		return fRequestThread;
307	resume_thread(fRequestThread);
308	return B_OK;
309}
310
311// _RequestThreadEntry
312int32
313AuthenticationServer::_RequestThreadEntry(void* data)
314{
315	return ((AuthenticationServer*)data)->_RequestThread();
316}
317
318// _RequestThread
319int32
320AuthenticationServer::_RequestThread()
321{
322	TaskManager taskManager;
323	while (!fTerminating) {
324		taskManager.RemoveDoneTasks();
325		// read the request
326		KMessage request;
327		status_t error = request.ReceiveFrom(fRequestPort);
328		if (error != B_OK)
329			continue;
330		// get the parameters
331		const char* context = NULL;
332		const char* server = NULL;
333		const char* share = NULL;
334		bool badPassword = true;
335		request.FindString("context", &context);
336		request.FindString("server", &server);
337		request.FindString("share", &share);
338		request.FindBool("badPassword", &badPassword);
339		if (!context || !server || !share)
340			continue;
341		HashString foundUser;
342		HashString foundPassword;
343		if (!badPassword && _GetAuthentication(context, server, share,
344			&foundUser, &foundPassword)) {
345			_SendRequestReply(request.ReplyPort(), request.ReplyToken(),
346				error, false, foundUser.GetString(), foundPassword.GetString());
347		} else {
348			// we need to ask the user: create a task that does it
349			UserDialogTask* task = new(std::nothrow) UserDialogTask(this,
350				context, server, share, badPassword, request.ReplyPort(),
351				request.ReplyToken());
352			if (!task) {
353				ERROR("AuthenticationServer::_RequestThread(): ERROR: "
354					"failed to allocate ");
355				continue;
356			}
357			status_t error = taskManager.RunTask(task);
358			if (error != B_OK) {
359				ERROR("AuthenticationServer::_RequestThread(): Failed to "
360					"start server info task: %s\n", strerror(error));
361				continue;
362			}
363		}
364	}
365	return 0;
366}
367
368// _GetAuthentication
369/*!
370	If share is NULL, the default authentication for the server is returned.
371*/
372bool
373AuthenticationServer::_GetAuthentication(const char* context,
374	const char* server, const char* share, HashString* user,
375	HashString* password)
376{
377	if (!context || !server || !user || !password)
378		return B_BAD_VALUE;
379	// get the server entry
380	AutoLocker<BLocker> _(fLock);
381	ServerKey key(context, server);
382	ServerEntry* serverEntry = fServerEntries->Get(key);
383	if (!serverEntry)
384		return false;
385	// get the authentication
386	const Authentication* authentication = NULL;
387	if (share) {
388		serverEntry->GetAuthentication(share);
389		if (!authentication && serverEntry->UseDefaultAuthentication())
390			authentication = &serverEntry->GetDefaultAuthentication();
391	} else
392		authentication = &serverEntry->GetDefaultAuthentication();
393	if (!authentication || !authentication->IsValid())
394		return false;
395	return (user->SetTo(authentication->GetUser())
396		&& password->SetTo(authentication->GetPassword()));
397}
398
399// _AddAuthentication
400status_t
401AuthenticationServer::_AddAuthentication(const char* context,
402	const char* server, const char* share, const char* user,
403	const char* password, bool makeDefault)
404{
405	AutoLocker<BLocker> _(fLock);
406	ServerKey key(context, server);
407	// get the server entry
408	ServerEntry* serverEntry = fServerEntries->Get(key);
409	if (!serverEntry) {
410		// server entry does not exist yet: create a new one
411		serverEntry = new(std::nothrow) ServerEntry;
412		if (!serverEntry)
413			return B_NO_MEMORY;
414		status_t error = fServerEntries->Put(key, serverEntry);
415		if (error != B_OK) {
416			delete serverEntry;
417			return error;
418		}
419	}
420	// put the authentication
421	status_t error = serverEntry->SetAuthentication(share, user, password);
422	if (error == B_OK) {
423		if (makeDefault || !serverEntry->UseDefaultAuthentication())
424			serverEntry->SetDefaultAuthentication(user, password);
425		if (makeDefault)
426			serverEntry->SetUseDefaultAuthentication(true);
427	}
428	return error;
429}
430
431// _SendRequestReply
432status_t
433AuthenticationServer::_SendRequestReply(port_id port, int32 token,
434	status_t error, bool cancelled, const char* user, const char* password)
435{
436	// prepare the reply
437	KMessage reply;
438	reply.AddInt32("error", error);
439	if (error == B_OK) {
440		reply.AddBool("cancelled", cancelled);
441		if (!cancelled) {
442			reply.AddString("user", user);
443			reply.AddString("password", password);
444		}
445	}
446	// send the reply
447	return reply.SendTo(port, token);
448}
449
450
451// main
452int
453main()
454{
455	AuthenticationServer app;
456	status_t error = app.Init();
457	if (error != B_OK)
458		return 1;
459	app.Run();
460	return 0;
461}
462
463