1/* -*- Mode: C; tab-width: 4 -*-
2 *
3 * Copyright (c) 2002-2004 Apple Computer, Inc. All rights reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Poll.h"
19#include <stdarg.h>
20#include <stddef.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24#include <winsock2.h>
25#include <ws2tcpip.h>
26#include <windows.h>
27#include <process.h>
28#include "GenLinkedList.h"
29#include "DebugServices.h"
30
31
// One registered wait source: either a socket together with the WSA event it was bound to via
// WSAEventSelect(), or a bare event handle. socket == INVALID_SOCKET marks a bare event source;
// WorkerDispatch() uses that to decide which member of the callback union to invoke.
typedef struct PollSource_struct
{
	SOCKET socket;		// INVALID_SOCKET for event sources
	HANDLE handle;		// Handle handed to WaitForMultipleObjects()
	void   *context;	// Opaque caller pointer passed back to the callback

	union
	{
		mDNSPollSocketCallback socket;	// Used when socket != INVALID_SOCKET
		mDNSPollEventCallback event;	// Used when socket == INVALID_SOCKET
	} callback;

	struct Worker_struct		*worker;	// Worker whose sources[]/handles[] arrays hold this source
	struct PollSource_struct	*next;		// Link for gPoll.sources (GenLinkedList)

} PollSource;
48
49
// A group of up to MAXIMUM_WAIT_OBJECTS handles waited on together. gPoll.main is waited on directly
// by mDNSPoll(); every additional worker runs WorkerMain() on its own thread so that more than
// MAXIMUM_WAIT_OBJECTS handles can be watched at once. WorkerInit() registers a source for the
// shared gPoll.wakeup event in every worker.
typedef struct Worker_struct
{
	HANDLE					thread;		// NULL for main worker
	unsigned				id;			// 0 for main worker; thread id from _beginthreadex_compat()

	HANDLE					start;		// NULL for main worker; auto-reset, signaled to begin one wait (or to exit when done is set)
	HANDLE					stop;		// NULL for main worker; auto-reset, signaled by WorkerMain() after each wait
	BOOL					done;		// Not used for main worker; set by PollRemoveWorker() to make the thread exit

	DWORD					numSources;	// Number of valid entries in sources[] and handles[]
	PollSource				*sources[ MAXIMUM_WAIT_OBJECTS ];	// sources[i] owns handles[i]
	HANDLE					handles[ MAXIMUM_WAIT_OBJECTS ];	// Array passed to WaitForMultipleObjects()
	DWORD					result;		// Return value of the last WaitForMultipleObjects(); read by WorkerDispatch()
	struct Worker_struct	*next;		// Link for gPoll.workers (GenLinkedList)
} Worker;
65
66
// Module-wide poll state; the single instance is gPoll.
typedef struct Poll_struct
{
	mDNSBool		setup;			// Set once PollSetup() has succeeded
	HANDLE			wakeup;			// Manual-reset event present in every worker's wait array; set to end all waits
	GenLinkedList	sources;		// Every PollSource registered through the public API
	DWORD			numSources;		// Number of entries in sources
	Worker			main;			// Worker waited on directly by mDNSPoll() on the caller's thread
	GenLinkedList	workers;		// Additional thread-backed workers
	HANDLE			workerHandles[ MAXIMUM_WAIT_OBJECTS ];	// Each thread worker's stop event, waited on by PollStopWorkers()
	DWORD			numWorkers;		// Number of valid entries in workerHandles

} Poll;
79
80
81/*
82 * Poll Methods
83 */
84
85mDNSlocal mStatus			PollSetup();
86mDNSlocal mStatus			PollRegisterSource( PollSource *source );
87mDNSlocal void				PollUnregisterSource( PollSource *source );
88mDNSlocal mStatus			PollStartWorkers();
89mDNSlocal mStatus			PollStopWorkers();
90mDNSlocal void				PollRemoveWorker( Worker *worker );
91
92
93/*
94 * Worker Methods
95 */
96
97mDNSlocal mStatus			WorkerInit( Worker *worker );
98mDNSlocal void				WorkerFree( Worker *worker );
99mDNSlocal void				WorkerRegisterSource( Worker *worker, PollSource *source );
100mDNSlocal int				WorkerSourceToIndex( Worker *worker, PollSource *source );
101mDNSlocal void				WorkerUnregisterSource( Worker *worker, PollSource *source );
102mDNSlocal void				WorkerDispatch( Worker *worker);
103mDNSlocal void CALLBACK		WorkerWakeupNotification( HANDLE event, void *context );
104mDNSlocal unsigned WINAPI	WorkerMain( LPVOID inParam );
105
106
// Remove the element at position ( index - 1 ) from an array of arraySize elements of itemSize bytes,
// by moving elements [ index, arraySize ) down one slot. The final slot is left as-is; callers track the
// logical element count themselves. index must be in [ 1, arraySize ]; anything else is ignored.
static void
ShiftDown( void * arr, size_t arraySize, size_t itemSize, int index )
{
	unsigned char	*base = ( unsigned char* ) arr;
	size_t			from;

	// index <= 0 would make ( index - 1 ) * itemSize wrap to a huge offset, and index > arraySize would
	// underflow the byte count; either way memmove would run far out of bounds.
	if ( ( arr == NULL ) || ( index <= 0 ) || ( ( size_t ) index > arraySize ) )
	{
		return;
	}

	from = ( size_t ) index;
	memmove( base + ( ( from - 1 ) * itemSize ), base + ( from * itemSize ), ( arraySize - from ) * itemSize );
}
112
113
114#define	DEBUG_NAME	"[mDNSWin32] "
115#define gMDNSRecord mDNSStorage
116mDNSlocal Poll gPoll = { mDNSfalse, NULL };
117
118#define LogErr( err, FUNC ) LogMsg( "%s:%d - %s failed: %d\n", __FUNCTION__, __LINE__, FUNC, err );
119
120
121mStatus
122mDNSPollRegisterSocket( SOCKET socket, int networkEvents, mDNSPollSocketCallback callback, void *context )
123{
124	PollSource	*source = NULL;
125	HANDLE		event = INVALID_HANDLE_VALUE;
126	mStatus		err = mStatus_NoError;
127
128	if ( !gPoll.setup )
129	{
130		err = PollSetup();
131		require_noerr( err, exit );
132	}
133
134	source = malloc( sizeof( PollSource ) );
135	require_action( source, exit, err = mStatus_NoMemoryErr );
136
137	event = WSACreateEvent();
138	require_action( event, exit, err = mStatus_NoMemoryErr );
139
140	err = WSAEventSelect( socket, event, networkEvents );
141	require_noerr( err, exit );
142
143	source->socket = socket;
144	source->handle = event;
145	source->callback.socket = callback;
146	source->context = context;
147
148	err = PollRegisterSource( source );
149	require_noerr( err, exit );
150
151exit:
152
153	if ( err != mStatus_NoError )
154	{
155		if ( event != INVALID_HANDLE_VALUE )
156		{
157			WSACloseEvent( event );
158		}
159
160		if ( source != NULL )
161		{
162			free( source );
163		}
164	}
165
166	return err;
167}
168
169
170void
171mDNSPollUnregisterSocket( SOCKET socket )
172{
173	PollSource	*source;
174
175	for ( source = gPoll.sources.Head; source; source = source->next )
176	{
177		if ( source->socket == socket )
178		{
179			break;
180		}
181	}
182
183	if ( source )
184	{
185		WSACloseEvent( source->handle );
186		PollUnregisterSource( source );
187		free( source );
188	}
189}
190
191
192mStatus
193mDNSPollRegisterEvent( HANDLE event, mDNSPollEventCallback callback, void *context )
194{
195	PollSource	*source = NULL;
196	mStatus		err = mStatus_NoError;
197
198	if ( !gPoll.setup )
199	{
200		err = PollSetup();
201		require_noerr( err, exit );
202	}
203
204	source = malloc( sizeof( PollSource ) );
205	require_action( source, exit, err = mStatus_NoMemoryErr );
206
207	source->socket = INVALID_SOCKET;
208	source->handle = event;
209	source->callback.event = callback;
210	source->context = context;
211
212	err = PollRegisterSource( source );
213	require_noerr( err, exit );
214
215exit:
216
217	if ( err != mStatus_NoError )
218	{
219		if ( source != NULL )
220		{
221			free( source );
222		}
223	}
224
225	return err;
226}
227
228
229void
230mDNSPollUnregisterEvent( HANDLE event )
231{
232	PollSource	*source;
233
234	for ( source = gPoll.sources.Head; source; source = source->next )
235	{
236		if ( source->handle == event )
237		{
238			break;
239		}
240	}
241
242	if ( source )
243	{
244		PollUnregisterSource( source );
245		free( source );
246	}
247}
248
249
250mStatus
251mDNSPoll( DWORD msec )
252{
253	mStatus err = mStatus_NoError;
254
255	if ( gPoll.numWorkers > 0 )
256	{
257		err = PollStartWorkers();
258		require_noerr( err, exit );
259	}
260
261	gPoll.main.result = WaitForMultipleObjects( gPoll.main.numSources, gPoll.main.handles, FALSE, msec );
262	err = translate_errno( ( gPoll.main.result != WAIT_FAILED ), ( mStatus ) GetLastError(), kUnknownErr );
263	if ( err ) LogErr( err, "WaitForMultipleObjects()" );
264	require_action( gPoll.main.result != WAIT_FAILED, exit, err = ( mStatus ) GetLastError() );
265
266	if ( gPoll.numWorkers > 0 )
267	{
268		err = PollStopWorkers();
269		require_noerr( err, exit );
270	}
271
272	WorkerDispatch( &gPoll.main );
273
274exit:
275
276	return ( err );
277}
278
279
280mDNSlocal mStatus
281PollSetup()
282{
283	mStatus err = mStatus_NoError;
284
285	if ( !gPoll.setup )
286	{
287		memset( &gPoll, 0, sizeof( gPoll ) );
288
289		InitLinkedList( &gPoll.sources, offsetof( PollSource, next ) );
290		InitLinkedList( &gPoll.workers, offsetof( Worker, next ) );
291
292		gPoll.wakeup = CreateEvent( NULL, TRUE, FALSE, NULL );
293		require_action( gPoll.wakeup, exit, err = mStatus_NoMemoryErr );
294
295		err = WorkerInit( &gPoll.main );
296		require_noerr( err, exit );
297
298		gPoll.setup = mDNStrue;
299	}
300
301exit:
302
303	return err;
304}
305
306
307mDNSlocal mStatus
308PollRegisterSource( PollSource *source )
309{
310	Worker	*worker = NULL;
311	mStatus err = mStatus_NoError;
312
313	AddToTail( &gPoll.sources, source );
314	gPoll.numSources++;
315
316	// First check our main worker. In most cases, we won't have to worry about threads
317
318	if ( gPoll.main.numSources < MAXIMUM_WAIT_OBJECTS )
319	{
320		WorkerRegisterSource( &gPoll.main, source );
321	}
322	else
323	{
324		// Try to find a thread to use that we've already created
325
326		for ( worker = gPoll.workers.Head; worker; worker = worker->next )
327		{
328			if ( worker->numSources < MAXIMUM_WAIT_OBJECTS )
329			{
330				WorkerRegisterSource( worker, source );
331				break;
332			}
333		}
334
335		// If not, then create a worker and make a thread to run it in
336
337		if ( !worker )
338		{
339			worker = ( Worker* ) malloc( sizeof( Worker ) );
340			require_action( worker, exit, err = mStatus_NoMemoryErr );
341
342			memset( worker, 0, sizeof( Worker ) );
343
344			worker->start = CreateEvent( NULL, FALSE, FALSE, NULL );
345			require_action( worker->start, exit, err = mStatus_NoMemoryErr );
346
347			worker->stop = CreateEvent( NULL, FALSE, FALSE, NULL );
348			require_action( worker->stop, exit, err = mStatus_NoMemoryErr );
349
350			err = WorkerInit( worker );
351			require_noerr( err, exit );
352
353			// Create thread with _beginthreadex() instead of CreateThread() to avoid
354			// memory leaks when using static run-time libraries.
355			// See <http://msdn.microsoft.com/library/default.asp?url=/library/en-us/dllproc/base/createthread.asp>.
356
357			worker->thread = ( HANDLE ) _beginthreadex_compat( NULL, 0, WorkerMain, worker, 0, &worker->id );
358			err = translate_errno( worker->thread, ( mStatus ) GetLastError(), kUnknownErr );
359			require_noerr( err, exit );
360
361			AddToTail( &gPoll.workers, worker );
362			gPoll.workerHandles[ gPoll.numWorkers++ ] = worker->stop;
363
364			WorkerRegisterSource( worker, source );
365		}
366	}
367
368exit:
369
370	if ( err && worker )
371	{
372		WorkerFree( worker );
373	}
374
375	return err;
376}
377
378
379mDNSlocal void
380PollUnregisterSource( PollSource *source )
381{
382	RemoveFromList( &gPoll.sources, source );
383	gPoll.numSources--;
384
385	WorkerUnregisterSource( source->worker, source );
386}
387
388
389mDNSlocal mStatus
390PollStartWorkers()
391{
392	Worker	*worker;
393	mStatus	err = mStatus_NoError;
394	BOOL	ok;
395
396	dlog( kDebugLevelChatty, DEBUG_NAME "starting workers\n" );
397
398	worker = gPoll.workers.Head;
399
400	while ( worker )
401	{
402		Worker *next = worker->next;
403
404		if ( worker->numSources == 1 )
405		{
406			PollRemoveWorker( worker );
407		}
408		else
409		{
410			dlog( kDebugLevelChatty, DEBUG_NAME "waking up worker\n" );
411
412			ok = SetEvent( worker->start );
413			err = translate_errno( ok, ( mStatus ) GetLastError(), kUnknownErr );
414			if ( err ) LogErr( err, "SetEvent()" );
415
416			if ( err )
417			{
418				PollRemoveWorker( worker );
419			}
420		}
421
422		worker = next;
423	}
424
425	err = mStatus_NoError;
426
427	return err;
428}
429
430
431mDNSlocal mStatus
432PollStopWorkers()
433{
434	DWORD	result;
435	Worker	*worker;
436	BOOL	ok;
437	mStatus	err = mStatus_NoError;
438
439	dlog( kDebugLevelChatty, DEBUG_NAME "stopping workers\n" );
440
441	ok = SetEvent( gPoll.wakeup );
442	err = translate_errno( ok, ( mStatus ) GetLastError(), kUnknownErr );
443	if ( err ) LogErr( err, "SetEvent()" );
444
445	// Wait For 5 seconds for all the workers to wake up
446
447	result = WaitForMultipleObjects( gPoll.numWorkers, gPoll.workerHandles, TRUE, 5000 );
448	err = translate_errno( ( result != WAIT_FAILED ), ( mStatus ) GetLastError(), kUnknownErr );
449	if ( err ) LogErr( err, "WaitForMultipleObjects()" );
450
451	ok = ResetEvent( gPoll.wakeup );
452	err = translate_errno( ok, ( mStatus ) GetLastError(), kUnknownErr );
453	if ( err ) LogErr( err, "ResetEvent()" );
454
455	for ( worker = gPoll.workers.Head; worker; worker = worker->next )
456	{
457		WorkerDispatch( worker );
458	}
459
460	err = mStatus_NoError;
461
462	return err;
463}
464
465
466mDNSlocal void
467PollRemoveWorker( Worker *worker )
468{
469	DWORD	result;
470	mStatus	err;
471	BOOL	ok;
472	DWORD	i;
473
474	dlog( kDebugLevelChatty, DEBUG_NAME "removing worker %d\n", worker->id );
475
476	RemoveFromList( &gPoll.workers, worker );
477
478	// Remove handle from gPoll.workerHandles
479
480	for ( i = 0; i < gPoll.numWorkers; i++ )
481	{
482		if ( gPoll.workerHandles[ i ] == worker->stop )
483		{
484			ShiftDown( gPoll.workerHandles, gPoll.numWorkers, sizeof( gPoll.workerHandles[ 0 ] ), i + 1 );
485			break;
486		}
487	}
488
489	worker->done = TRUE;
490	gPoll.numWorkers--;
491
492	// Cause the thread to exit.
493
494	ok = SetEvent( worker->start );
495	err = translate_errno( ok, ( OSStatus ) GetLastError(), kUnknownErr );
496	if ( err ) LogErr( err, "SetEvent()" );
497
498	result = WaitForSingleObject( worker->thread, 5000 );
499	err = translate_errno( result != WAIT_FAILED, ( OSStatus ) GetLastError(), kUnknownErr );
500	if ( err ) LogErr( err, "WaitForSingleObject()" );
501
502	if ( ( result == WAIT_FAILED ) || ( result == WAIT_TIMEOUT ) )
503	{
504		ok = TerminateThread( worker->thread, 0 );
505		err = translate_errno( ok, ( OSStatus ) GetLastError(), kUnknownErr );
506		if ( err ) LogErr( err, "TerminateThread()" );
507	}
508
509	CloseHandle( worker->thread );
510	worker->thread = NULL;
511
512	WorkerFree( worker );
513}
514
515
516mDNSlocal void
517WorkerRegisterSource( Worker *worker, PollSource *source )
518{
519	source->worker = worker;
520	worker->sources[ worker->numSources ] = source;
521	worker->handles[ worker->numSources ] = source->handle;
522	worker->numSources++;
523}
524
525
526mDNSlocal int
527WorkerSourceToIndex( Worker *worker, PollSource *source )
528{
529	int index;
530
531	for ( index = 0; index < ( int ) worker->numSources; index++ )
532	{
533		if ( worker->sources[ index ] == source )
534		{
535			break;
536		}
537	}
538
539	if ( index == ( int ) worker->numSources )
540	{
541		index = -1;
542	}
543
544	return index;
545}
546
547
548mDNSlocal void
549WorkerUnregisterSource( Worker *worker, PollSource *source )
550{
551	int sourceIndex = WorkerSourceToIndex( worker, source );
552	DWORD delta;
553
554	if ( sourceIndex == -1 )
555	{
556		LogMsg( "WorkerUnregisterSource: source not found in list" );
557		goto exit;
558	}
559
560	delta = ( worker->numSources - sourceIndex - 1 );
561
562	// If this source is not at the end of the list, then move memory
563
564	if ( delta > 0 )
565	{
566		ShiftDown( worker->sources, worker->numSources, sizeof( worker->sources[ 0 ] ), sourceIndex + 1 );
567		ShiftDown( worker->handles, worker->numSources, sizeof( worker->handles[ 0 ] ), sourceIndex + 1 );
568	}
569
570	worker->numSources--;
571
572exit:
573
574	return;
575}
576
577
578mDNSlocal void CALLBACK
579WorkerWakeupNotification( HANDLE event, void *context )
580{
581	DEBUG_UNUSED( event );
582	DEBUG_UNUSED( context );
583
584	dlog( kDebugLevelChatty, DEBUG_NAME "Worker thread wakeup\n" );
585}
586
587
588mDNSlocal void
589WorkerDispatch( Worker *worker )
590{
591	if ( worker->result == WAIT_FAILED )
592	{
593		/* What should we do here? */
594	}
595	else if ( worker->result == WAIT_TIMEOUT )
596	{
597		dlog( kDebugLevelChatty, DEBUG_NAME "timeout\n" );
598	}
599	else
600	{
601		DWORD		waitItemIndex = ( DWORD )( ( ( int ) worker->result ) - WAIT_OBJECT_0 );
602		PollSource	*source = NULL;
603
604		// Sanity check
605
606		if ( waitItemIndex >= worker->numSources )
607		{
608			LogMsg( "WorkerDispatch: waitItemIndex (%d) is >= numSources (%d)", waitItemIndex, worker->numSources );
609			goto exit;
610		}
611
612		source = worker->sources[ waitItemIndex ];
613
614		if ( source->socket != INVALID_SOCKET )
615		{
616			WSANETWORKEVENTS event;
617
618			if ( WSAEnumNetworkEvents( source->socket, source->handle, &event ) == 0 )
619			{
620				source->callback.socket( source->socket, &event, source->context );
621			}
622			else
623			{
624				source->callback.socket( source->socket, NULL, source->context );
625			}
626		}
627		else
628		{
629			source->callback.event( source->handle, source->context );
630		}
631	}
632
633exit:
634
635	return;
636}
637
638
639mDNSlocal mStatus
640WorkerInit( Worker *worker )
641{
642	PollSource *source = NULL;
643	mStatus err = mStatus_NoError;
644
645	require_action( worker, exit, err = mStatus_BadParamErr );
646
647	source = malloc( sizeof( PollSource ) );
648	require_action( source, exit, err = mStatus_NoMemoryErr );
649
650	source->socket = INVALID_SOCKET;
651	source->handle = gPoll.wakeup;
652	source->callback.event = WorkerWakeupNotification;
653	source->context = NULL;
654
655	WorkerRegisterSource( worker, source );
656
657exit:
658
659	return err;
660}
661
662
663mDNSlocal void
664WorkerFree( Worker *worker )
665{
666	if ( worker->start )
667	{
668		CloseHandle( worker->start );
669		worker->start = NULL;
670	}
671
672	if ( worker->stop )
673	{
674		CloseHandle( worker->stop );
675		worker->stop = NULL;
676	}
677
678	free( worker );
679}
680
681
682mDNSlocal unsigned WINAPI
683WorkerMain( LPVOID inParam )
684{
685	Worker *worker = ( Worker* ) inParam;
686	mStatus err = mStatus_NoError;
687
688	require_action( worker, exit, err = mStatus_BadParamErr );
689
690	dlog( kDebugLevelVerbose, DEBUG_NAME, "entering WorkerMain()\n" );
691
692	while ( TRUE )
693	{
694		DWORD	result;
695		BOOL	ok;
696
697		dlog( kDebugLevelChatty, DEBUG_NAME, "worker thread %d will wait on main loop\n", worker->id );
698
699		result = WaitForSingleObject( worker->start, INFINITE );
700		err = translate_errno( ( result != WAIT_FAILED ), ( mStatus ) GetLastError(), kUnknownErr );
701		if ( err ) { LogErr( err, "WaitForSingleObject()" ); break; }
702		if ( worker->done ) break;
703
704		dlog( kDebugLevelChatty, DEBUG_NAME "worker thread %d will wait on sockets\n", worker->id );
705
706		worker->result = WaitForMultipleObjects( worker->numSources, worker->handles, FALSE, INFINITE );
707		err = translate_errno( ( worker->result != WAIT_FAILED ), ( mStatus ) GetLastError(), kUnknownErr );
708		if ( err ) { LogErr( err, "WaitForMultipleObjects()" ); break; }
709
710		dlog( kDebugLevelChatty, DEBUG_NAME "worker thread %d did wait on sockets: %d\n", worker->id, worker->result );
711
712		ok = SetEvent( gPoll.wakeup );
713		err = translate_errno( ok, ( mStatus ) GetLastError(), kUnknownErr );
714		if ( err ) { LogErr( err, "SetEvent()" ); break; }
715
716		dlog( kDebugLevelChatty, DEBUG_NAME, "worker thread %d preparing to sleep\n", worker->id );
717
718		ok = SetEvent( worker->stop );
719		err = translate_errno( ok, ( mStatus ) GetLastError(), kUnknownErr );
720		if ( err ) { LogErr( err, "SetEvent()" ); break; }
721	}
722
723	dlog( kDebugLevelVerbose, DEBUG_NAME "exiting WorkerMain()\n" );
724
725exit:
726
727	return 0;
728}
729