1/*-
2 * Copyright (c) 2006 Robert N. M. Watson
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24 * SUCH DAMAGE.
25 *
26 * $FreeBSD$
27 */
28
29#include <sys/types.h>
30#include <sys/socket.h>
31#include <sys/stat.h>
32#include <sys/wait.h>
33
34#include <netinet/in.h>
35
36#include <err.h>
37#include <errno.h>
38#include <fcntl.h>
39#include <limits.h>
40#include <md5.h>
41#include <signal.h>
42#include <stdint.h>
43#include <stdio.h>
44#include <stdlib.h>
45#include <string.h>
46#include <unistd.h>
47
48/*
49 * Simple regression test for sendfile.  Creates a file sized at four pages
50 * and then proceeds to send it over a series of sockets, exercising a number
51 * of cases and performing limited validation.
52 */
53
/*
 * Failure helpers for functions returning int: print a diagnostic line
 * prefixed with "# " and return -1 from the calling function.  FAIL_ERR
 * additionally appends strerror(errno).
 *
 * The expansions are bare brace blocks, not do { } while (0), and most call
 * sites omit the trailing semicolon (e.g. FAIL_ERR("read") on its own line).
 * Changing the macro form therefore requires updating every caller; also
 * never write "if (x) FAIL(...); else ...", which would not compile.
 */
#define FAIL(msg)	{printf("# %s\n", msg); \
			return (-1);}

#define FAIL_ERR(msg)	{printf("# %s: %s\n", msg, strerror(errno)); \
			return (-1);}
59
/* Fixed parameters shared by the sending parent and the receiving child. */
enum {
	TEST_PORT	= 5678,		/* loopback TCP port the child listens on */
	TEST_MAGIC	= 0x4440f7bb,	/* sentinel stored in th_magic */
	TEST_PAGES	= 4,		/* size of the test file, in pages */
	TEST_SECONDS	= 30		/* timeout armed in the child */
};
64
/*
 * Per-test header written by the parent ahead of the sendfile() payload and
 * read back by the child.  The integer fields are carried in network byte
 * order.  The structure is written and read as raw bytes, so both sides
 * must agree on its exact layout.
 */
struct test_header {
	uint32_t	th_magic;		/* TEST_MAGIC */
	uint32_t	th_header_length;	/* bytes of header data sent before the file data */
	uint32_t	th_offset;		/* file offset handed to sendfile() */
	uint32_t	th_length;		/* bytes of file data expected */
	char		th_md5[33];		/* MD5 digest string of the file chunk; presumably 32 hex digits + NUL */
};
72
/*
 * One sendfile() test case.  A length of 0 means "from offset to the end of
 * the file"; the expected length is then derived from the file size.
 */
struct sendfile_test {
	uint32_t	hdr_length;	/* bytes of header data to send first */
	uint32_t	offset;		/* starting offset within the test file */
	uint32_t	length;		/* bytes of file data, or 0 for "to EOF" */
};
78
/* Descriptor of the test file, opened in main(). */
int	file_fd;

/* Path of the test file; cleanup() unlinks it at exit when non-empty. */
char	path[PATH_MAX];

/*
 * The child's sockets.  They live at file scope so that the SIGALRM handler
 * can close them on timeout.
 */
int	listen_socket, accept_socket;
83
/* Timeout handling. */
static void signal_alarm(int signum);
static void setup_alarm(int seconds);
static void cancel_alarm(void);

/* Receiving side (forked child). */
static int test_th(struct test_header *th, uint32_t *header_length,
		uint32_t *offset, uint32_t *length);
static int receive_test(void);
static void run_child(void);

/* Sending side (parent). */
static int new_test_socket(int *connect_socket);
static void init_th(struct test_header *th, uint32_t header_length,
		uint32_t offset, uint32_t length);
static int send_test(int connect_socket, struct sendfile_test test);
static void run_parent(void);

/* atexit() handler. */
static void cleanup(void);
97
98
99static int
100test_th(struct test_header *th, uint32_t *header_length, uint32_t *offset,
101		uint32_t *length)
102{
103
104	if (th->th_magic != htonl(TEST_MAGIC))
105		FAIL("magic number not found in header")
106	*header_length = ntohl(th->th_header_length);
107	*offset = ntohl(th->th_offset);
108	*length = ntohl(th->th_length);
109	return (0);
110}
111
112static void
113signal_alarm(int signum)
114{
115	(void)signum;
116
117	printf("# test timeout\n");
118
119	if (accept_socket > 0)
120		close(accept_socket);
121	if (listen_socket > 0)
122		close(listen_socket);
123
124	_exit(-1);
125}
126
127static void
128setup_alarm(int seconds)
129{
130	struct itimerval itv;
131	bzero(&itv, sizeof(itv));
132	(void)seconds;
133	itv.it_value.tv_sec = seconds;
134
135	signal(SIGALRM, signal_alarm);
136	setitimer(ITIMER_REAL, &itv, NULL);
137}
138
/*
 * Disarm the ITIMER_REAL timer armed by setup_alarm() by loading an
 * all-zero itimerval.  The SIGALRM disposition is left unchanged.
 */
static void
cancel_alarm(void)
{
	struct itimerval disarm = {
		.it_interval = { 0, 0 },
		.it_value = { 0, 0 }
	};

	setitimer(ITIMER_REAL, &disarm, NULL);
}
146
147static int
148receive_test(void)
149{
150	uint32_t header_length, offset, length, counter;
151	struct test_header th;
152	ssize_t len;
153	char buf[10240];
154	MD5_CTX md5ctx;
155	char *rxmd5;
156
157	len = read(accept_socket, &th, sizeof(th));
158	if (len < 0 || (size_t)len < sizeof(th))
159		FAIL_ERR("read")
160
161	if (test_th(&th, &header_length, &offset, &length) != 0)
162		return (-1);
163
164	MD5Init(&md5ctx);
165
166	counter = 0;
167	while (1) {
168		len = read(accept_socket, buf, sizeof(buf));
169		if (len < 0 || len == 0)
170			break;
171		counter += len;
172		MD5Update(&md5ctx, buf, len);
173	}
174
175	rxmd5 = MD5End(&md5ctx, NULL);
176
177	if ((counter != header_length+length) ||
178			memcmp(th.th_md5, rxmd5, 33) != 0)
179		FAIL("receive length mismatch")
180
181	free(rxmd5);
182	return (0);
183}
184
185static void
186run_child(void)
187{
188	struct sockaddr_in sin;
189	int rc = 0;
190
191	listen_socket = socket(PF_INET, SOCK_STREAM, 0);
192	if (listen_socket < 0) {
193		printf("# socket: %s\n", strerror(errno));
194		rc = -1;
195	}
196
197	if (!rc) {
198		bzero(&sin, sizeof(sin));
199		sin.sin_len = sizeof(sin);
200		sin.sin_family = AF_INET;
201		sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
202		sin.sin_port = htons(TEST_PORT);
203
204		if (bind(listen_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
205			printf("# bind: %s\n", strerror(errno));
206			rc = -1;
207		}
208	}
209
210	if (!rc && listen(listen_socket, -1) < 0) {
211		printf("# listen: %s\n", strerror(errno));
212		rc = -1;
213	}
214
215	if (!rc) {
216		accept_socket = accept(listen_socket, NULL, NULL);
217		setup_alarm(TEST_SECONDS);
218		if (receive_test() != 0)
219			rc = -1;
220	}
221
222	cancel_alarm();
223	if (accept_socket > 0)
224		close(accept_socket);
225	if (listen_socket > 0)
226		close(listen_socket);
227
228	_exit(rc);
229}
230
231static int
232new_test_socket(int *connect_socket)
233{
234	struct sockaddr_in sin;
235	int rc = 0;
236
237	*connect_socket = socket(PF_INET, SOCK_STREAM, 0);
238	if (*connect_socket < 0)
239		FAIL_ERR("socket")
240
241	bzero(&sin, sizeof(sin));
242	sin.sin_len = sizeof(sin);
243	sin.sin_family = AF_INET;
244	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
245	sin.sin_port = htons(TEST_PORT);
246
247	if (connect(*connect_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0)
248		FAIL_ERR("connect")
249
250	return (rc);
251}
252
253static void
254init_th(struct test_header *th, uint32_t header_length, uint32_t offset,
255		uint32_t length)
256{
257	bzero(th, sizeof(*th));
258	th->th_magic = htonl(TEST_MAGIC);
259	th->th_header_length = htonl(header_length);
260	th->th_offset = htonl(offset);
261	th->th_length = htonl(length);
262
263	MD5FileChunk(path, th->th_md5, offset, length);
264}
265
266static int
267send_test(int connect_socket, struct sendfile_test test)
268{
269	struct test_header th;
270	struct sf_hdtr hdtr, *hdtrp;
271	struct iovec headers;
272	char *header;
273	ssize_t len;
274	int length;
275	off_t off;
276
277	len = lseek(file_fd, 0, SEEK_SET);
278	if (len != 0)
279		FAIL_ERR("lseek")
280
281	if (test.length == 0) {
282		struct stat st;
283		if (fstat(file_fd, &st) < 0)
284			FAIL_ERR("fstat")
285		length = st.st_size - test.offset;
286	}
287	else {
288		length = test.length;
289	}
290
291	init_th(&th, test.hdr_length, test.offset, length);
292
293	len = write(connect_socket, &th, sizeof(th));
294	if (len != sizeof(th))
295		return (-1);
296
297	if (test.hdr_length != 0) {
298		header = malloc(test.hdr_length);
299		if (header == NULL)
300			FAIL_ERR("malloc")
301
302		hdtrp = &hdtr;
303		bzero(&headers, sizeof(headers));
304		headers.iov_base = header;
305		headers.iov_len = test.hdr_length;
306		bzero(&hdtr, sizeof(hdtr));
307		hdtr.headers = &headers;
308		hdtr.hdr_cnt = 1;
309		hdtr.trailers = NULL;
310		hdtr.trl_cnt = 0;
311	} else {
312		hdtrp = NULL;
313		header = NULL;
314	}
315
316	if (sendfile(file_fd, connect_socket, test.offset, test.length,
317				hdtrp, &off, 0) < 0) {
318		if (header != NULL)
319			free(header);
320		FAIL_ERR("sendfile")
321	}
322
323	if (length == 0) {
324		struct stat sb;
325
326		if (fstat(file_fd, &sb) == 0)
327			length = sb.st_size - test.offset;
328	}
329
330	if (header != NULL)
331		free(header);
332
333	if (off != length)
334		FAIL("offset != length")
335
336	return (0);
337}
338
339static void
340run_parent(void)
341{
342	int connect_socket;
343	int status;
344	int test_num;
345	int pid;
346
347	const int pagesize = getpagesize();
348
349	struct sendfile_test tests[10] = {
350 		{ .hdr_length = 0, .offset = 0, .length = 1 },
351		{ .hdr_length = 0, .offset = 0, .length = pagesize },
352		{ .hdr_length = 0, .offset = 1, .length = 1 },
353		{ .hdr_length = 0, .offset = 1, .length = pagesize },
354		{ .hdr_length = 0, .offset = pagesize, .length = pagesize },
355		{ .hdr_length = 0, .offset = 0, .length = 2*pagesize },
356		{ .hdr_length = 0, .offset = 0, .length = 0 },
357		{ .hdr_length = 0, .offset = pagesize, .length = 0 },
358		{ .hdr_length = 0, .offset = 2*pagesize, .length = 0 },
359		{ .hdr_length = 0, .offset = TEST_PAGES*pagesize, .length = 0 }
360	};
361
362	printf("1..10\n");
363
364	for (test_num = 1; test_num <= 10; test_num++) {
365
366		pid = fork();
367		if (pid == -1) {
368			printf("not ok %d\n", test_num);
369			continue;
370		}
371
372		if (pid == 0)
373			run_child();
374
375		usleep(250000);
376
377		if (new_test_socket(&connect_socket) != 0) {
378			printf("not ok %d\n", test_num);
379			kill(pid, SIGALRM);
380			close(connect_socket);
381			continue;
382		}
383
384		if (send_test(connect_socket, tests[test_num-1]) != 0) {
385			printf("not ok %d\n", test_num);
386			kill(pid, SIGALRM);
387			close(connect_socket);
388			continue;
389		}
390
391		close(connect_socket);
392		if (waitpid(pid, &status, 0) == pid) {
393			if (WIFEXITED(status) && WEXITSTATUS(status) == 0)
394				printf("%s %d\n", "ok", test_num);
395			else
396				printf("%s %d\n", "not ok", test_num);
397		}
398		else {
399			printf("not ok %d\n", test_num);
400		}
401	}
402}
403
404static void
405cleanup(void)
406{
407	if (*path != '\0')
408		unlink(path);
409}
410
411int
412main(int argc, char *argv[])
413{
414	char *page_buffer;
415	int pagesize;
416	ssize_t len;
417
418	*path = '\0';
419
420	pagesize = getpagesize();
421	page_buffer = malloc(TEST_PAGES * pagesize);
422	if (page_buffer == NULL)
423		FAIL_ERR("malloc")
424	bzero(page_buffer, TEST_PAGES * pagesize);
425
426	if (argc == 1) {
427		snprintf(path, PATH_MAX, "/tmp/sendfile.XXXXXXXXXXXX");
428		file_fd = mkstemp(path);
429		if (file_fd == -1)
430			FAIL_ERR("mkstemp");
431	} else if (argc == 2) {
432		(void)strlcpy(path, argv[1], sizeof(path));
433		file_fd = open(path, O_CREAT | O_TRUNC | O_RDWR, 0600);
434		if (file_fd == -1)
435			FAIL_ERR("open");
436	} else {
437		FAIL("usage: sendfile [path]");
438	}
439
440	atexit(cleanup);
441
442	len = write(file_fd, page_buffer, TEST_PAGES * pagesize);
443	if (len < 0)
444		FAIL_ERR("write")
445
446	len = lseek(file_fd, 0, SEEK_SET);
447	if (len < 0)
448		FAIL_ERR("lseek")
449	if (len != 0)
450		FAIL("len != 0")
451
452	run_parent();
453	return (0);
454}
455