1/*
2 * Copyright 2008, Oliver Tappe, zooey@hirschkaefer.de.
3 * Distributed under the terms of the MIT license.
4 */
5
6
7#include <Message.h>
8#include <NetEndpoint.h>
9
10#include <errno.h>
11#include <netinet/in.h>
12#include <stdio.h>
13#include <stdlib.h>
14#include <string.h>
15#include <sys/wait.h>
16
17
18static BNetAddress serverAddr("127.0.0.1", 1234);
19static BNetAddress clientAddr("127.0.0.1", 51234);
20
21
22static int problemCount = 0;
23
24
25void
26checkAddrsAreEqual(const BNetAddress& na1, const BNetAddress& na2,
27	const char* fmt)
28{
29	in_addr addr1, addr2;
30	unsigned short port1, port2;
31	na1.GetAddr(addr1, &port1);
32	na2.GetAddr(addr2, &port2);
33	if (addr1.s_addr == addr2.s_addr && port1 == port2)
34		return;
35	fprintf(stderr, fmt, addr1.s_addr, port1, addr2.s_addr, port2);
36	exit(1);
37}
38
39
40void
41checkArchive(const BNetEndpoint ne, int32 protocol,
42	const BNetAddress& localNetAddress, const BNetAddress& remoteNetAddress)
43{
44	in_addr localAddr, remoteAddr;
45	unsigned short localPort, remotePort;
46	localNetAddress.GetAddr(localAddr, &localPort);
47	remoteNetAddress.GetAddr(remoteAddr, &remotePort);
48
49	BMessage archive;
50	status_t status = ne.Archive(&archive);
51	if (status != B_OK) {
52		fprintf(stderr, "Archive() failed - %lx:%s\n", status,
53			strerror(status));
54		problemCount++;
55		exit(1);
56	}
57	const char* arcClass;
58	if (archive.FindString("class", &arcClass) != B_OK) {
59		fprintf(stderr, "'class' not found in archive\n");
60		problemCount++;
61		exit(1);
62	}
63	if (strcmp(arcClass, "BNetEndpoint") != 0) {
64		fprintf(stderr, "expected 'class' to be 'BNetEndpoint' - is '%s'\n",
65			arcClass);
66		problemCount++;
67		exit(1);
68	}
69
70	if (ne.LocalAddr().InitCheck() == B_OK) {
71		int32 arcAddr;
72		if (archive.FindInt32("_BNetEndpoint_addr_addr", &arcAddr) != B_OK) {
73			fprintf(stderr, "'_BNetEndpoint_addr_addr' not found in archive\n");
74			problemCount++;
75			exit(1);
76		}
77		if ((uint32)localAddr.s_addr != (uint32)arcAddr) {
78			fprintf(stderr,
79				"expected '_BNetEndpoint_addr_addr' to be %x - is %x\n",
80				localAddr.s_addr, (unsigned int)arcAddr);
81			problemCount++;
82			exit(1);
83		}
84		int16 arcPort;
85		if (archive.FindInt16("_BNetEndpoint_addr_port", &arcPort) != B_OK) {
86			fprintf(stderr, "'_BNetEndpoint_addr_port' not found in archive\n");
87			problemCount++;
88			exit(1);
89		}
90		if ((uint16)localPort != (uint16)arcPort) {
91			fprintf(stderr,
92				"expected '_BNetEndpoint_addr_port' to be %d - is %d\n",
93				localPort, (int)arcPort);
94			problemCount++;
95			exit(1);
96		}
97	}
98
99	if (ne.RemoteAddr().InitCheck() == B_OK) {
100		int32 arcAddr;
101		if (archive.FindInt32("_BNetEndpoint_peer_addr", &arcAddr) != B_OK) {
102			fprintf(stderr, "'_BNetEndpoint_peer_addr' not found in archive\n");
103			problemCount++;
104			exit(1);
105		}
106		if ((uint32)remoteAddr.s_addr != (uint32)arcAddr) {
107			fprintf(stderr,
108				"expected '_BNetEndpoint_peer_addr' to be %x - is %x\n",
109				remoteAddr.s_addr, (unsigned int)arcAddr);
110			problemCount++;
111			exit(1);
112		}
113		int16 arcPort;
114		if (archive.FindInt16("_BNetEndpoint_peer_port", &arcPort) != B_OK) {
115			fprintf(stderr, "'_BNetEndpoint_peer_port' not found in archive\n");
116			problemCount++;
117			exit(1);
118		}
119		if ((uint16)remotePort != (uint16)arcPort) {
120			fprintf(stderr,
121				"expected '_BNetEndpoint_peer_port' to be %u - is %u\n",
122				remotePort, (unsigned short)arcPort);
123			problemCount++;
124			exit(1);
125		}
126	}
127
128	int64 arcTimeout;
129	if (archive.FindInt64("_BNetEndpoint_timeout", &arcTimeout) != B_OK) {
130		fprintf(stderr, "'_BNetEndpoint_timeout' not found in archive\n");
131		problemCount++;
132		exit(1);
133	}
134	if (arcTimeout != B_INFINITE_TIMEOUT) {
135		fprintf(stderr,
136			"expected '_BNetEndpoint_timeout' to be %llu - is %llu\n",
137			B_INFINITE_TIMEOUT, (uint64)arcTimeout);
138		problemCount++;
139		exit(1);
140	}
141
142	int32 arcProtocol;
143	if (archive.FindInt32("_BNetEndpoint_proto", &arcProtocol) != B_OK) {
144		fprintf(stderr, "'_BNetEndpoint_proto' not found in archive\n");
145		problemCount++;
146		exit(1);
147	}
148	if (arcProtocol != protocol) {
149		fprintf(stderr, "expected '_BNetEndpoint_proto' to be %d - is %d\n",
150			(int)protocol, (int)arcProtocol);
151		problemCount++;
152		exit(1);
153	}
154
155	BNetEndpoint* clone
156		= dynamic_cast<BNetEndpoint *>(BNetEndpoint::Instantiate(&archive));
157	if (!clone) {
158		fprintf(stderr, "unable to instantiate endpoint from archive\n");
159		problemCount++;
160		exit(1);
161	}
162	delete clone;
163}
164
165void testServer(thread_id clientThread)
166{
167	char buf[1];
168
169	// check simple UDP "connection"
170	BNetEndpoint server(SOCK_DGRAM);
171	for(int i=0; i < 2; ++i) {
172		status_t status = server.Bind(serverAddr);
173		if (status != B_OK) {
174			fprintf(stderr, "Bind() failed in testServer - %s\n",
175				strerror(status));
176			problemCount++;
177			exit(1);
178		}
179
180		checkAddrsAreEqual(server.LocalAddr(), serverAddr,
181			"LocalAddr() doesn't match serverAddr\n");
182
183		if (i == 0)
184			resume_thread(clientThread);
185
186		BNetAddress remoteAddr;
187		status = server.ReceiveFrom(buf, 1, remoteAddr, 0);
188		if (status < B_OK) {
189			fprintf(stderr, "ReceiveFrom() failed in testServer - %s\n",
190				strerror(status));
191			problemCount++;
192			exit(1);
193		}
194
195		if (buf[0] != 'U') {
196			fprintf(stderr, "expected to receive %c but got %c\n", 'U', buf[0]);
197			problemCount++;
198			exit(1);
199		}
200
201		checkAddrsAreEqual(remoteAddr, clientAddr,
202			"remoteAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
203
204		checkArchive(server, SOCK_DGRAM, serverAddr, clientAddr);
205
206		server.Close();
207	}
208
209	// now switch to TCP and try again
210	server.SetProtocol(SOCK_STREAM);
211	status_t status = server.Bind(serverAddr);
212	if (status != B_OK) {
213		fprintf(stderr, "Bind() failed in testServer - %s\n",
214			strerror(status));
215		problemCount++;
216		exit(1);
217	}
218
219	checkAddrsAreEqual(server.LocalAddr(), serverAddr,
220		"LocalAddr() doesn't match serverAddr\n");
221
222	status = server.Listen();
223	BNetEndpoint* acceptedConn = server.Accept();
224	if (acceptedConn == NULL) {
225		fprintf(stderr, "Accept() failed in testServer\n");
226		problemCount++;
227		exit(1);
228	}
229
230	const BNetAddress& remoteAddr = acceptedConn->RemoteAddr();
231	checkAddrsAreEqual(remoteAddr, clientAddr,
232		"remoteAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
233
234	status = acceptedConn->Receive(buf, 1);
235	if (status < B_OK) {
236		fprintf(stderr, "Receive() failed in testServer - %s\n",
237			strerror(status));
238		problemCount++;
239		exit(1);
240	}
241	delete acceptedConn;
242
243	if (buf[0] != 'T') {
244		fprintf(stderr, "expected to receive %c but got %c\n", 'T', buf[0]);
245		problemCount++;
246		exit(1);
247	}
248
249	checkArchive(server, SOCK_STREAM, serverAddr, clientAddr);
250
251	server.Close();
252}
253
254
255int32 testClient(void *)
256{
257	BNetEndpoint client(SOCK_DGRAM);
258	printf("testing udp...\n");
259	for(int i=0; i < 2; ++i) {
260		status_t status = client.Bind(clientAddr);
261		if (status != B_OK) {
262			fprintf(stderr, "Bind() failed in testClient - %s\n",
263				strerror(status));
264			problemCount++;
265			exit(1);
266		}
267
268		checkAddrsAreEqual(client.LocalAddr(), clientAddr,
269			"LocalAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
270
271		status = client.SendTo("U", 1, serverAddr, 0);
272		if (status < B_OK) {
273			fprintf(stderr, "SendTo() failed in testClient - %s\n",
274				strerror(status));
275			problemCount++;
276			exit(1);
277		}
278
279		checkArchive(client, SOCK_DGRAM, clientAddr, serverAddr);
280
281		sleep(1);
282
283		client.Close();
284	}
285
286	sleep(1);
287
288	printf("testing tcp...\n");
289	// now switch to TCP and try again
290	client.SetProtocol(SOCK_STREAM);
291	status_t status = client.Bind(clientAddr);
292	if (status != B_OK) {
293		fprintf(stderr, "Bind() failed in testClient - %s\n",
294			strerror(status));
295		problemCount++;
296		exit(1);
297	}
298
299	checkAddrsAreEqual(client.LocalAddr(), clientAddr,
300		"LocalAddr(%x:%d) doesn't match clientAddr(%x:%d)\n");
301
302	status = client.Connect(serverAddr);
303	if (status < B_OK) {
304		fprintf(stderr, "Connect() failed in testClient - %s\n",
305			strerror(status));
306		problemCount++;
307		exit(1);
308	}
309	status = client.Send("T", 1);
310	if (status < B_OK) {
311		fprintf(stderr, "Send() failed in testClient - %s\n",
312			strerror(status));
313		problemCount++;
314		exit(1);
315	}
316
317	checkArchive(client, SOCK_STREAM, clientAddr, serverAddr);
318
319	client.Close();
320
321	return B_OK;
322}
323
324
325int
326main(int argc, const char* const* argv)
327{
328	BNetEndpoint dummy(SOCK_DGRAM);
329	if (sizeof(dummy) != 208) {
330		fprintf(stderr, "expected sizeof(netEndpoint) to be 208 - is %ld\n",
331			sizeof(dummy));
332		exit(1);
333	}
334	dummy.Close();
335
336	// start thread for client
337	thread_id tid = spawn_thread(testClient, "client", B_NORMAL_PRIORITY, NULL);
338	if (tid < 0) {
339		fprintf(stderr, "spawn_thread() failed: %s\n", strerror(tid));
340		exit(1);
341	}
342
343	testServer(tid);
344
345	status_t clientStatus;
346	wait_for_thread(tid, &clientStatus);
347
348	if (!problemCount)
349		printf("Everything went fine.\n");
350
351	return 0;
352}
353