1/*-
2 * Copyright (c) 2006 Robert N. M. Watson
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24 * SUCH DAMAGE.
25 *
26 * $FreeBSD$
27 */
28
29#include <sys/types.h>
30#include <sys/socket.h>
31#include <sys/stat.h>
32#include <sys/wait.h>
33
34#include <netinet/in.h>
35
36#include <err.h>
37#include <errno.h>
38#include <fcntl.h>
39#include <limits.h>
40#include <md5.h>
41#include <signal.h>
42#include <stdint.h>
43#include <stdio.h>
44#include <stdlib.h>
45#include <string.h>
46#include <unistd.h>
47
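/*
 * Simple regression test for sendfile(2).  Creates a test file (by default
 * TEST_PAGES pages long) and sends ranges of it over a series of loopback
 * TCP connections to a forked child, exercising a number of offset/length
 * cases and performing limited validation: the child checks the number of
 * bytes received and the MD5 digest of the file data.  Results are printed
 * in TAP form ("1..N", "ok N" / "not ok N").
 */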
53
/*
 * Print a TAP diagnostic line ("# msg") and return -1 from the enclosing
 * function.  The expansion is a bare brace block rather than
 * do { } while (0), and most call sites in this file omit the trailing
 * semicolon, so the form cannot change without updating every caller.
 */
#define FAIL(msg)	{printf("# %s\n", msg); \
			return (-1);}

/* As FAIL(), but also append strerror(errno) for the failed call. */
#define FAIL_ERR(msg)	{printf("# %s: %s\n", msg, strerror(errno)); \
			return (-1);}
59
/* Loopback TCP port the child listens on and the parent connects to. */
#define	TEST_PORT	5678
/* Value stored (in network byte order) in test_header.th_magic. */
#define	TEST_MAGIC	0x4440f7bb
/* Default size of the test file, in pages. */
#define	TEST_PAGES	4
/* Watchdog timeout for the receiving child, in seconds. */
#define	TEST_SECONDS	30
64
/*
 * Header the parent writes to the socket ahead of the sendfile() data.
 * Integer fields are in network byte order (see init_th()).  The struct is
 * written and read as raw bytes; both ends are the same forked program, so
 * the layout matches, but it must not be reordered casually.
 */
struct test_header {
	uint32_t	th_magic;		/* TEST_MAGIC */
	uint32_t	th_header_length;	/* bytes of sf_hdtr header data */
	uint32_t	th_offset;		/* file offset handed to sendfile() */
	uint32_t	th_length;		/* file bytes expected after headers */
	char		th_md5[33];		/* MD5FileChunk() digest string, NUL incl. */
};
72
/*
 * One test case.  A zero length means "send through end of file"; a zero
 * file_size means use the default file of TEST_PAGES pages.
 */
struct sendfile_test {
	uint32_t	hdr_length;	/* bytes of header data sent via sf_hdtr */
	uint32_t	offset;		/* starting file offset */
	uint32_t	length;		/* bytes to send; 0 = to EOF */
	uint32_t	file_size;	/* test file size; 0 = default */
};
79
/* Descriptor and pathname of the file whose contents are sent. */
static int	file_fd;
static char	path[PATH_MAX];
/* Receiving child's sockets; also closed by the SIGALRM handler. */
static int	listen_socket;
static int	accept_socket;
84
/* Forward declarations of the file-local helpers defined below. */
static int test_th(struct test_header *th, uint32_t *header_length,
		uint32_t *offset, uint32_t *length);
static void signal_alarm(int signum);
static void setup_alarm(int seconds);
static void cancel_alarm(void);
static int receive_test(void);
static void run_child(void);
static int new_test_socket(int *connect_socket);
static void init_th(struct test_header *th, uint32_t header_length,
		uint32_t offset, uint32_t length);
static int send_test(int connect_socket, struct sendfile_test);
static int write_test_file(size_t file_size);
static void run_parent(void);
static void cleanup(void);
99
100
101static int
102test_th(struct test_header *th, uint32_t *header_length, uint32_t *offset,
103		uint32_t *length)
104{
105
106	if (th->th_magic != htonl(TEST_MAGIC))
107		FAIL("magic number not found in header")
108	*header_length = ntohl(th->th_header_length);
109	*offset = ntohl(th->th_offset);
110	*length = ntohl(th->th_length);
111	return (0);
112}
113
114static void
115signal_alarm(int signum)
116{
117	(void)signum;
118
119	printf("# test timeout\n");
120
121	if (accept_socket > 0)
122		close(accept_socket);
123	if (listen_socket > 0)
124		close(listen_socket);
125
126	_exit(-1);
127}
128
129static void
130setup_alarm(int seconds)
131{
132	struct itimerval itv;
133	bzero(&itv, sizeof(itv));
134	(void)seconds;
135	itv.it_value.tv_sec = seconds;
136
137	signal(SIGALRM, signal_alarm);
138	setitimer(ITIMER_REAL, &itv, NULL);
139}
140
/*
 * Disarm the ITIMER_REAL timer.  An all-zero it_value stops it.  The
 * SIGALRM disposition is left unchanged.
 */
static void
cancel_alarm(void)
{
	static const struct itimerval disarmed;

	setitimer(ITIMER_REAL, &disarmed, NULL);
}
148
149static int
150receive_test(void)
151{
152	uint32_t header_length, offset, length, counter;
153	struct test_header th;
154	ssize_t len;
155	char buf[10240];
156	MD5_CTX md5ctx;
157	char *rxmd5;
158
159	len = read(accept_socket, &th, sizeof(th));
160	if (len < 0 || (size_t)len < sizeof(th))
161		FAIL_ERR("read")
162
163	if (test_th(&th, &header_length, &offset, &length) != 0)
164		return (-1);
165
166	MD5Init(&md5ctx);
167
168	counter = 0;
169	while (1) {
170		len = read(accept_socket, buf, sizeof(buf));
171		if (len < 0 || len == 0)
172			break;
173		counter += len;
174		MD5Update(&md5ctx, buf, len);
175	}
176
177	rxmd5 = MD5End(&md5ctx, NULL);
178
179	if ((counter != header_length+length) ||
180			memcmp(th.th_md5, rxmd5, 33) != 0)
181		FAIL("receive length mismatch")
182
183	free(rxmd5);
184	return (0);
185}
186
187static void
188run_child(void)
189{
190	struct sockaddr_in sin;
191	int rc = 0;
192
193	listen_socket = socket(PF_INET, SOCK_STREAM, 0);
194	if (listen_socket < 0) {
195		printf("# socket: %s\n", strerror(errno));
196		rc = -1;
197	}
198
199	if (!rc) {
200		bzero(&sin, sizeof(sin));
201		sin.sin_len = sizeof(sin);
202		sin.sin_family = AF_INET;
203		sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
204		sin.sin_port = htons(TEST_PORT);
205
206		if (bind(listen_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
207			printf("# bind: %s\n", strerror(errno));
208			rc = -1;
209		}
210	}
211
212	if (!rc && listen(listen_socket, -1) < 0) {
213		printf("# listen: %s\n", strerror(errno));
214		rc = -1;
215	}
216
217	if (!rc) {
218		accept_socket = accept(listen_socket, NULL, NULL);
219		setup_alarm(TEST_SECONDS);
220		if (receive_test() != 0)
221			rc = -1;
222	}
223
224	cancel_alarm();
225	if (accept_socket > 0)
226		close(accept_socket);
227	if (listen_socket > 0)
228		close(listen_socket);
229
230	_exit(rc);
231}
232
233static int
234new_test_socket(int *connect_socket)
235{
236	struct sockaddr_in sin;
237	int rc = 0;
238
239	*connect_socket = socket(PF_INET, SOCK_STREAM, 0);
240	if (*connect_socket < 0)
241		FAIL_ERR("socket")
242
243	bzero(&sin, sizeof(sin));
244	sin.sin_len = sizeof(sin);
245	sin.sin_family = AF_INET;
246	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
247	sin.sin_port = htons(TEST_PORT);
248
249	if (connect(*connect_socket, (struct sockaddr *)&sin, sizeof(sin)) < 0)
250		FAIL_ERR("connect")
251
252	return (rc);
253}
254
255static void
256init_th(struct test_header *th, uint32_t header_length, uint32_t offset,
257		uint32_t length)
258{
259	bzero(th, sizeof(*th));
260	th->th_magic = htonl(TEST_MAGIC);
261	th->th_header_length = htonl(header_length);
262	th->th_offset = htonl(offset);
263	th->th_length = htonl(length);
264
265	MD5FileChunk(path, th->th_md5, offset, length);
266}
267
268static int
269send_test(int connect_socket, struct sendfile_test test)
270{
271	struct test_header th;
272	struct sf_hdtr hdtr, *hdtrp;
273	struct iovec headers;
274	char *header;
275	ssize_t len;
276	int length;
277	off_t off;
278
279	len = lseek(file_fd, 0, SEEK_SET);
280	if (len != 0)
281		FAIL_ERR("lseek")
282
283	struct stat st;
284	if (fstat(file_fd, &st) < 0)
285		FAIL_ERR("fstat")
286	length = st.st_size - test.offset;
287	if (test.length > 0 && test.length < (uint32_t)length)
288		length = test.length;
289
290	init_th(&th, test.hdr_length, test.offset, length);
291
292	len = write(connect_socket, &th, sizeof(th));
293	if (len != sizeof(th))
294		return (-1);
295
296	if (test.hdr_length != 0) {
297		header = malloc(test.hdr_length);
298		if (header == NULL)
299			FAIL_ERR("malloc")
300
301		hdtrp = &hdtr;
302		bzero(&headers, sizeof(headers));
303		headers.iov_base = header;
304		headers.iov_len = test.hdr_length;
305		bzero(&hdtr, sizeof(hdtr));
306		hdtr.headers = &headers;
307		hdtr.hdr_cnt = 1;
308		hdtr.trailers = NULL;
309		hdtr.trl_cnt = 0;
310	} else {
311		hdtrp = NULL;
312		header = NULL;
313	}
314
315	if (sendfile(file_fd, connect_socket, test.offset, test.length,
316				hdtrp, &off, 0) < 0) {
317		if (header != NULL)
318			free(header);
319		FAIL_ERR("sendfile")
320	}
321
322	if (length == 0) {
323		struct stat sb;
324
325		if (fstat(file_fd, &sb) == 0)
326			length = sb.st_size - test.offset;
327	}
328
329	if (header != NULL)
330		free(header);
331
332	if (off != length)
333		FAIL("offset != length")
334
335	return (0);
336}
337
338static int
339write_test_file(size_t file_size)
340{
341	char *page_buffer;
342	ssize_t len;
343	static size_t current_file_size = 0;
344
345	if (file_size == current_file_size)
346		return (0);
347	else if (file_size < current_file_size) {
348		if (ftruncate(file_fd, file_size) != 0)
349			FAIL_ERR("ftruncate");
350		current_file_size = file_size;
351		return (0);
352	}
353
354	page_buffer = malloc(file_size);
355	if (page_buffer == NULL)
356		FAIL_ERR("malloc")
357	bzero(page_buffer, file_size);
358
359	len = write(file_fd, page_buffer, file_size);
360	if (len < 0)
361		FAIL_ERR("write")
362
363	len = lseek(file_fd, 0, SEEK_SET);
364	if (len < 0)
365		FAIL_ERR("lseek")
366	if (len != 0)
367		FAIL("len != 0")
368
369	free(page_buffer);
370	current_file_size = file_size;
371	return (0);
372}
373
374static void
375run_parent(void)
376{
377	int connect_socket;
378	int status;
379	int test_num;
380	int test_count;
381	int pid;
382	size_t desired_file_size = 0;
383
384	const int pagesize = getpagesize();
385
386	struct sendfile_test tests[] = {
387 		{ .hdr_length = 0, .offset = 0, .length = 1 },
388		{ .hdr_length = 0, .offset = 0, .length = pagesize },
389		{ .hdr_length = 0, .offset = 1, .length = 1 },
390		{ .hdr_length = 0, .offset = 1, .length = pagesize },
391		{ .hdr_length = 0, .offset = pagesize, .length = pagesize },
392		{ .hdr_length = 0, .offset = 0, .length = 2*pagesize },
393		{ .hdr_length = 0, .offset = 0, .length = 0 },
394		{ .hdr_length = 0, .offset = pagesize, .length = 0 },
395		{ .hdr_length = 0, .offset = 2*pagesize, .length = 0 },
396		{ .hdr_length = 0, .offset = TEST_PAGES*pagesize, .length = 0 },
397		{ .hdr_length = 0, .offset = 0, .length = pagesize,
398		    .file_size = 1 }
399	};
400
401	test_count = sizeof(tests) / sizeof(tests[0]);
402	printf("1..%d\n", test_count);
403
404	for (test_num = 1; test_num <= test_count; test_num++) {
405
406		desired_file_size = tests[test_num - 1].file_size;
407		if (desired_file_size == 0)
408			desired_file_size = TEST_PAGES * pagesize;
409		if (write_test_file(desired_file_size) != 0) {
410			printf("not ok %d\n", test_num);
411			continue;
412		}
413
414		pid = fork();
415		if (pid == -1) {
416			printf("not ok %d\n", test_num);
417			continue;
418		}
419
420		if (pid == 0)
421			run_child();
422
423		usleep(250000);
424
425		if (new_test_socket(&connect_socket) != 0) {
426			printf("not ok %d\n", test_num);
427			kill(pid, SIGALRM);
428			close(connect_socket);
429			continue;
430		}
431
432		if (send_test(connect_socket, tests[test_num-1]) != 0) {
433			printf("not ok %d\n", test_num);
434			kill(pid, SIGALRM);
435			close(connect_socket);
436			continue;
437		}
438
439		close(connect_socket);
440		if (waitpid(pid, &status, 0) == pid) {
441			if (WIFEXITED(status) && WEXITSTATUS(status) == 0)
442				printf("%s %d\n", "ok", test_num);
443			else
444				printf("%s %d\n", "not ok", test_num);
445		}
446		else {
447			printf("not ok %d\n", test_num);
448		}
449	}
450}
451
452static void
453cleanup(void)
454{
455
456	unlink(path);
457}
458
459int
460main(int argc, char *argv[])
461{
462
463	path[0] = '\0';
464
465	if (argc == 1) {
466		snprintf(path, sizeof(path), "sendfile.XXXXXXXXXXXX");
467		file_fd = mkstemp(path);
468		if (file_fd == -1)
469			FAIL_ERR("mkstemp");
470	} else if (argc == 2) {
471		(void)strlcpy(path, argv[1], sizeof(path));
472		file_fd = open(path, O_CREAT | O_TRUNC | O_RDWR, 0600);
473		if (file_fd == -1)
474			FAIL_ERR("open");
475	} else {
476		FAIL("usage: sendfile [path]");
477	}
478
479	atexit(cleanup);
480
481	run_parent();
482	return (0);
483}
484