preload.c revision 331769
1/*
2 * Copyright (c) 2011-2012 Intel Corporation.  All rights reserved.
3 *
4 * This software is available to you under a choice of one of two
5 * licenses.  You may choose to be licensed under the terms of the GNU
6 * General Public License (GPL) Version 2, available from the file
7 * COPYING in the main directory of this source tree, or the
8 * OpenIB.org BSD license below:
9 *
10 *     Redistribution and use in source and binary forms, with or
11 *     without modification, are permitted provided that the following
12 *     conditions are met:
13 *
14 *      - Redistributions of source code must retain the above
15 *        copyright notice, this list of conditions and the following
16 *        disclaimer.
17 *
18 *      - Redistributions in binary form must reproduce the above
19 *        copyright notice, this list of conditions and the following
20 *        disclaimer in the documentation and/or other materials
21 *        provided with the distribution.
22 *
23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 * SOFTWARE.
31 *
32 */
33#define _GNU_SOURCE
34#include <config.h>
35
36#include <sys/types.h>
37#include <sys/socket.h>
38#include <sys/uio.h>
39#include <sys/stat.h>
40#include <sys/mman.h>
41#include <stdarg.h>
42#include <dlfcn.h>
43#include <netdb.h>
44#include <unistd.h>
45#include <fcntl.h>
46#include <string.h>
47#include <netinet/tcp.h>
48#include <unistd.h>
49#include <semaphore.h>
50#include <ctype.h>
51#include <stdlib.h>
52#include <stdio.h>
53
54#include <rdma/rdma_cma.h>
55#include <rdma/rdma_verbs.h>
56#include <rdma/rsocket.h>
57#include "cma.h"
58#include "indexer.h"
59
/*
 * Table of socket-related entry points.  Two instances exist in this file:
 * 'real' and 'rs'; init_preload() fills both in through dlsym().
 * init_preload() does not assign rs.dup2, rs.sendfile or rs.fxstat.
 */
struct socket_calls {
	int (*socket)(int domain, int type, int protocol);
	int (*bind)(int socket, const struct sockaddr *addr, socklen_t addrlen);
	int (*listen)(int socket, int backlog);
	int (*accept)(int socket, struct sockaddr *addr, socklen_t *addrlen);
	int (*connect)(int socket, const struct sockaddr *addr, socklen_t addrlen);
	ssize_t (*recv)(int socket, void *buf, size_t len, int flags);
	ssize_t (*recvfrom)(int socket, void *buf, size_t len, int flags,
			    struct sockaddr *src_addr, socklen_t *addrlen);
	ssize_t (*recvmsg)(int socket, struct msghdr *msg, int flags);
	ssize_t (*read)(int socket, void *buf, size_t count);
	ssize_t (*readv)(int socket, const struct iovec *iov, int iovcnt);
	ssize_t (*send)(int socket, const void *buf, size_t len, int flags);
	ssize_t (*sendto)(int socket, const void *buf, size_t len, int flags,
			  const struct sockaddr *dest_addr, socklen_t addrlen);
	ssize_t (*sendmsg)(int socket, const struct msghdr *msg, int flags);
	ssize_t (*write)(int socket, const void *buf, size_t count);
	ssize_t (*writev)(int socket, const struct iovec *iov, int iovcnt);
	int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout);
	int (*shutdown)(int socket, int how);
	int (*close)(int socket);
	int (*getpeername)(int socket, struct sockaddr *addr, socklen_t *addrlen);
	int (*getsockname)(int socket, struct sockaddr *addr, socklen_t *addrlen);
	int (*setsockopt)(int socket, int level, int optname,
			  const void *optval, socklen_t optlen);
	int (*getsockopt)(int socket, int level, int optname,
			  void *optval, socklen_t *optlen);
	int (*fcntl)(int socket, int cmd, ... /* arg */);
	int (*dup2)(int oldfd, int newfd);
	ssize_t (*sendfile)(int out_fd, int in_fd, off_t *offset, size_t count);
	int (*fxstat)(int ver, int fd, struct stat *buf);
};
92
/* Entry points resolved with dlsym(RTLD_NEXT, ...) in init_preload(). */
static struct socket_calls real;
/* "r"-prefixed entry points resolved with dlsym(RTLD_DEFAULT, ...). */
static struct socket_calls rs;

/* Maps an application-visible descriptor (index) to its struct fd_info. */
static struct index_map idm;
/* Held around idm_set() calls and around the one-time setup in init_preload(). */
static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;

/*
 * Tunables loaded by getenv_options(); a value of 0 is treated as "not set"
 * by the code that consumes them.
 */
static int sq_size;		/* RS_SQ_SIZE */
static int rq_size;		/* RS_RQ_SIZE */
static int sq_inline;		/* RS_INLINE */
static int fork_support;	/* RDMAV_FORK_SAFE */
103
/* Selects which API services a tracked descriptor. */
enum fd_type {
	fd_normal,	/* forwarded to the real.* entry points */
	fd_rsocket	/* forwarded to the rsocket (r*) entry points */
};
108
/* Progress of the delayed (fork-support) conversion of a tracked descriptor. */
enum fd_fork_state {
	fd_ready,		/* no conversion pending */
	fd_fork,		/* set by socket() when fork support applies */
	fd_fork_listen,		/* set by listen() on an fd_fork socket */
	fd_fork_active,		/* set by connect() on an fd_fork socket */
	fd_fork_passive		/* set by accept() on an fd_fork_listen socket */
};
116
/* Per-descriptor bookkeeping stored in idm, keyed by the application's fd. */
struct fd_info {
	enum fd_type type;		/* which API 'fd' belongs to */
	enum fd_fork_state state;	/* pending fork-support transition */
	int fd;				/* underlying real socket or rsocket */
	int dupfd;			/* index this entry was dup2()'ed from, or -1 */
	_Atomic(int) refcnt;		/* starts at 1; dup2() bumps the source entry */
};
124
/*
 * One parsed line of the preload_config file.  scan_config() leaves a field
 * zero/NULL when the file specified the '*' wildcard for it.
 */
struct config_entry {
	char *name;	/* strdup()'ed program name, NULL for '*' */
	int domain;	/* AF_INET, AF_INET6, AF_IB, or 0 */
	int type;	/* SOCK_STREAM, SOCK_DGRAM, or 0 */
	int protocol;	/* IPPROTO_TCP, IPPROTO_UDP, or 0 */
};
131
/* Table built by scan_config() and released by free_config(). */
static struct config_entry *config;
/* Number of valid entries in config[]. */
static int config_cnt;
134
135static void free_config(void)
136{
137	while (config_cnt)
138		free(config[--config_cnt].name);
139
140	free(config);
141}
142
143/*
144 * Config file format:
145 * # Starting '#' indicates comment
146 * # wild card values are supported using '*'
147 * # domain - *, INET, INET6, IB
148 * # type - *, STREAM, DGRAM
149 * # protocol - *, TCP, UDP
150 * program_name domain type protocol
151 */
152static void scan_config(void)
153{
154	struct config_entry *new_config;
155	FILE *fp;
156	char line[120], prog[64], dom[16], type[16], proto[16];
157
158	fp = fopen(RS_CONF_DIR "/preload_config", "r");
159	if (!fp)
160		return;
161
162	while (fgets(line, sizeof(line), fp)) {
163		if (line[0] == '#')
164			continue;
165
166		if (sscanf(line, "%64s%16s%16s%16s", prog, dom, type, proto) != 4)
167			continue;
168
169		new_config = realloc(config, (config_cnt + 1) *
170					     sizeof(struct config_entry));
171		if (!new_config)
172			break;
173
174		config = new_config;
175		memset(&config[config_cnt], 0, sizeof(struct config_entry));
176
177		if (!strcasecmp(dom, "INET") ||
178		    !strcasecmp(dom, "AF_INET") ||
179		    !strcasecmp(dom, "PF_INET")) {
180			config[config_cnt].domain = AF_INET;
181		} else if (!strcasecmp(dom, "INET6") ||
182			   !strcasecmp(dom, "AF_INET6") ||
183			   !strcasecmp(dom, "PF_INET6")) {
184			config[config_cnt].domain = AF_INET6;
185		} else if (!strcasecmp(dom, "IB") ||
186			   !strcasecmp(dom, "AF_IB") ||
187			   !strcasecmp(dom, "PF_IB")) {
188			config[config_cnt].domain = AF_IB;
189		} else if (strcmp(dom, "*")) {
190			continue;
191		}
192
193		if (!strcasecmp(type, "STREAM") ||
194		    !strcasecmp(type, "SOCK_STREAM")) {
195			config[config_cnt].type = SOCK_STREAM;
196		} else if (!strcasecmp(type, "DGRAM") ||
197			   !strcasecmp(type, "SOCK_DGRAM")) {
198			config[config_cnt].type = SOCK_DGRAM;
199		} else if (strcmp(type, "*")) {
200			continue;
201		}
202
203		if (!strcasecmp(proto, "TCP") ||
204		    !strcasecmp(proto, "IPPROTO_TCP")) {
205			config[config_cnt].protocol = IPPROTO_TCP;
206		} else if (!strcasecmp(proto, "UDP") ||
207			   !strcasecmp(proto, "IPPROTO_UDP")) {
208			config[config_cnt].protocol = IPPROTO_UDP;
209		} else if (strcmp(proto, "*")) {
210			continue;
211		}
212
213		if (strcmp(prog, "*")) {
214		    if (!(config[config_cnt].name = strdup(prog)))
215			    continue;
216		}
217
218		config_cnt++;
219	}
220
221	fclose(fp);
222	if (config_cnt)
223		atexit(free_config);
224}
225
226static int intercept_socket(int domain, int type, int protocol)
227{
228	int i;
229
230	if (!config_cnt)
231		return 1;
232
233	if (!protocol) {
234		if (type == SOCK_STREAM)
235			protocol = IPPROTO_TCP;
236		else if (type == SOCK_DGRAM)
237			protocol = IPPROTO_UDP;
238	}
239
240	for (i = 0; i < config_cnt; i++) {
241		if ((!config[i].name ||
242		     !strncasecmp(config[i].name, program_invocation_short_name,
243				  strlen(config[i].name))) &&
244		    (!config[i].domain || config[i].domain == domain) &&
245		    (!config[i].type || config[i].type == type) &&
246		    (!config[i].protocol || config[i].protocol == protocol))
247			return 1;
248	}
249
250	return 0;
251}
252
253static int fd_open(void)
254{
255	struct fd_info *fdi;
256	int ret, index;
257
258	fdi = calloc(1, sizeof(*fdi));
259	if (!fdi)
260		return ERR(ENOMEM);
261
262	index = open("/dev/null", O_RDONLY);
263	if (index < 0) {
264		ret = index;
265		goto err1;
266	}
267
268	fdi->dupfd = -1;
269	atomic_store(&fdi->refcnt, 1);
270	pthread_mutex_lock(&mut);
271	ret = idm_set(&idm, index, fdi);
272	pthread_mutex_unlock(&mut);
273	if (ret < 0)
274		goto err2;
275
276	return index;
277
278err2:
279	real.close(index);
280err1:
281	free(fdi);
282	return ret;
283}
284
285static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state)
286{
287	struct fd_info *fdi;
288
289	fdi = idm_at(&idm, index);
290	fdi->fd = fd;
291	fdi->type = type;
292	fdi->state = state;
293}
294
295static inline enum fd_type fd_get(int index, int *fd)
296{
297	struct fd_info *fdi;
298
299	fdi = idm_lookup(&idm, index);
300	if (fdi) {
301		*fd = fdi->fd;
302		return fdi->type;
303
304	} else {
305		*fd = index;
306		return fd_normal;
307	}
308}
309
310static inline int fd_getd(int index)
311{
312	struct fd_info *fdi;
313
314	fdi = idm_lookup(&idm, index);
315	return fdi ? fdi->fd : index;
316}
317
318static inline enum fd_fork_state fd_gets(int index)
319{
320	struct fd_info *fdi;
321
322	fdi = idm_lookup(&idm, index);
323	return fdi ? fdi->state : fd_ready;
324}
325
326static inline enum fd_type fd_gett(int index)
327{
328	struct fd_info *fdi;
329
330	fdi = idm_lookup(&idm, index);
331	return fdi ? fdi->type : fd_normal;
332}
333
334static enum fd_type fd_close(int index, int *fd)
335{
336	struct fd_info *fdi;
337	enum fd_type type;
338
339	fdi = idm_lookup(&idm, index);
340	if (fdi) {
341		idm_clear(&idm, index);
342		*fd = fdi->fd;
343		type = fdi->type;
344		real.close(index);
345		free(fdi);
346	} else {
347		*fd = index;
348		type = fd_normal;
349	}
350	return type;
351}
352
353static void getenv_options(void)
354{
355	char *var;
356
357	var = getenv("RS_SQ_SIZE");
358	if (var)
359		sq_size = atoi(var);
360
361	var = getenv("RS_RQ_SIZE");
362	if (var)
363		rq_size = atoi(var);
364
365	var = getenv("RS_INLINE");
366	if (var)
367		sq_inline = atoi(var);
368
369	var = getenv("RDMAV_FORK_SAFE");
370	if (var)
371		fork_support = atoi(var);
372}
373
374static void init_preload(void)
375{
376	static int init;
377
378	/* Quick check without lock */
379	if (init)
380		return;
381
382	pthread_mutex_lock(&mut);
383	if (init)
384		goto out;
385
386	real.socket = dlsym(RTLD_NEXT, "socket");
387	real.bind = dlsym(RTLD_NEXT, "bind");
388	real.listen = dlsym(RTLD_NEXT, "listen");
389	real.accept = dlsym(RTLD_NEXT, "accept");
390	real.connect = dlsym(RTLD_NEXT, "connect");
391	real.recv = dlsym(RTLD_NEXT, "recv");
392	real.recvfrom = dlsym(RTLD_NEXT, "recvfrom");
393	real.recvmsg = dlsym(RTLD_NEXT, "recvmsg");
394	real.read = dlsym(RTLD_NEXT, "read");
395	real.readv = dlsym(RTLD_NEXT, "readv");
396	real.send = dlsym(RTLD_NEXT, "send");
397	real.sendto = dlsym(RTLD_NEXT, "sendto");
398	real.sendmsg = dlsym(RTLD_NEXT, "sendmsg");
399	real.write = dlsym(RTLD_NEXT, "write");
400	real.writev = dlsym(RTLD_NEXT, "writev");
401	real.poll = dlsym(RTLD_NEXT, "poll");
402	real.shutdown = dlsym(RTLD_NEXT, "shutdown");
403	real.close = dlsym(RTLD_NEXT, "close");
404	real.getpeername = dlsym(RTLD_NEXT, "getpeername");
405	real.getsockname = dlsym(RTLD_NEXT, "getsockname");
406	real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
407	real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
408	real.fcntl = dlsym(RTLD_NEXT, "fcntl");
409	real.dup2 = dlsym(RTLD_NEXT, "dup2");
410	real.sendfile = dlsym(RTLD_NEXT, "sendfile");
411	real.fxstat = dlsym(RTLD_NEXT, "__fxstat");
412
413	rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
414	rs.bind = dlsym(RTLD_DEFAULT, "rbind");
415	rs.listen = dlsym(RTLD_DEFAULT, "rlisten");
416	rs.accept = dlsym(RTLD_DEFAULT, "raccept");
417	rs.connect = dlsym(RTLD_DEFAULT, "rconnect");
418	rs.recv = dlsym(RTLD_DEFAULT, "rrecv");
419	rs.recvfrom = dlsym(RTLD_DEFAULT, "rrecvfrom");
420	rs.recvmsg = dlsym(RTLD_DEFAULT, "rrecvmsg");
421	rs.read = dlsym(RTLD_DEFAULT, "rread");
422	rs.readv = dlsym(RTLD_DEFAULT, "rreadv");
423	rs.send = dlsym(RTLD_DEFAULT, "rsend");
424	rs.sendto = dlsym(RTLD_DEFAULT, "rsendto");
425	rs.sendmsg = dlsym(RTLD_DEFAULT, "rsendmsg");
426	rs.write = dlsym(RTLD_DEFAULT, "rwrite");
427	rs.writev = dlsym(RTLD_DEFAULT, "rwritev");
428	rs.poll = dlsym(RTLD_DEFAULT, "rpoll");
429	rs.shutdown = dlsym(RTLD_DEFAULT, "rshutdown");
430	rs.close = dlsym(RTLD_DEFAULT, "rclose");
431	rs.getpeername = dlsym(RTLD_DEFAULT, "rgetpeername");
432	rs.getsockname = dlsym(RTLD_DEFAULT, "rgetsockname");
433	rs.setsockopt = dlsym(RTLD_DEFAULT, "rsetsockopt");
434	rs.getsockopt = dlsym(RTLD_DEFAULT, "rgetsockopt");
435	rs.fcntl = dlsym(RTLD_DEFAULT, "rfcntl");
436
437	getenv_options();
438	scan_config();
439	init = 1;
440out:
441	pthread_mutex_unlock(&mut);
442}
443
444/*
445 * We currently only handle copying a few common values.
446 */
447static int copysockopts(int dfd, int sfd, struct socket_calls *dapi,
448			struct socket_calls *sapi)
449{
450	socklen_t len;
451	int param, ret;
452
453	ret = sapi->fcntl(sfd, F_GETFL);
454	if (ret > 0)
455		ret = dapi->fcntl(dfd, F_SETFL, ret);
456	if (ret)
457		return ret;
458
459	len = sizeof param;
460	ret = sapi->getsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &param, &len);
461	if (param && !ret)
462		ret = dapi->setsockopt(dfd, SOL_SOCKET, SO_REUSEADDR, &param, len);
463	if (ret)
464		return ret;
465
466	len = sizeof param;
467	ret = sapi->getsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &param, &len);
468	if (param && !ret)
469		ret = dapi->setsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, len);
470	if (ret)
471		return ret;
472
473	return 0;
474}
475
476/*
477 * Convert between an rsocket and a normal socket.
478 */
479static int transpose_socket(int socket, enum fd_type new_type)
480{
481	socklen_t len = 0;
482	int sfd, dfd, param, ret;
483	struct socket_calls *sapi, *dapi;
484
485	sfd = fd_getd(socket);
486	if (new_type == fd_rsocket) {
487		dapi = &rs;
488		sapi = &real;
489	} else {
490		dapi = &real;
491		sapi = &rs;
492	}
493
494	ret = sapi->getsockname(sfd, NULL, &len);
495	if (ret)
496		return ret;
497
498	param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET;
499	dfd = dapi->socket(param, SOCK_STREAM, 0);
500	if (dfd < 0)
501		return dfd;
502
503	ret = copysockopts(dfd, sfd, dapi, sapi);
504	if (ret)
505		goto err;
506
507	fd_store(socket, dfd, new_type, fd_ready);
508	return dfd;
509
510err:
511	dapi->close(dfd);
512	return ret;
513}
514
515/*
516 * Use defaults on failure.
517 */
518static void set_rsocket_options(int rsocket)
519{
520	if (sq_size)
521		rsetsockopt(rsocket, SOL_RDMA, RDMA_SQSIZE, &sq_size, sizeof sq_size);
522
523	if (rq_size)
524		rsetsockopt(rsocket, SOL_RDMA, RDMA_RQSIZE, &rq_size, sizeof rq_size);
525
526	if (sq_inline)
527		rsetsockopt(rsocket, SOL_RDMA, RDMA_INLINE, &sq_inline, sizeof sq_inline);
528}
529
530int socket(int domain, int type, int protocol)
531{
532	static __thread int recursive;
533	int index, ret;
534
535	init_preload();
536
537	if (recursive || !intercept_socket(domain, type, protocol))
538		goto real;
539
540	index = fd_open();
541	if (index < 0)
542		return index;
543
544	if (fork_support && (domain == PF_INET || domain == PF_INET6) &&
545	    (type == SOCK_STREAM) && (!protocol || protocol == IPPROTO_TCP)) {
546		ret = real.socket(domain, type, protocol);
547		if (ret < 0)
548			return ret;
549		fd_store(index, ret, fd_normal, fd_fork);
550		return index;
551	}
552
553	recursive = 1;
554	ret = rsocket(domain, type, protocol);
555	recursive = 0;
556	if (ret >= 0) {
557		fd_store(index, ret, fd_rsocket, fd_ready);
558		set_rsocket_options(ret);
559		return index;
560	}
561	fd_close(index, &ret);
562real:
563	return real.socket(domain, type, protocol);
564}
565
566int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
567{
568	int fd;
569	return (fd_get(socket, &fd) == fd_rsocket) ?
570		rbind(fd, addr, addrlen) : real.bind(fd, addr, addrlen);
571}
572
573int listen(int socket, int backlog)
574{
575	int fd, ret;
576	if (fd_get(socket, &fd) == fd_rsocket) {
577		ret = rlisten(fd, backlog);
578	} else {
579		ret = real.listen(fd, backlog);
580		if (!ret && fd_gets(socket) == fd_fork)
581			fd_store(socket, fd, fd_normal, fd_fork_listen);
582	}
583	return ret;
584}
585
586int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
587{
588	int fd, index, ret;
589
590	if (fd_get(socket, &fd) == fd_rsocket) {
591		index = fd_open();
592		if (index < 0)
593			return index;
594
595		ret = raccept(fd, addr, addrlen);
596		if (ret < 0) {
597			fd_close(index, &fd);
598			return ret;
599		}
600
601		fd_store(index, ret, fd_rsocket, fd_ready);
602		return index;
603	} else if (fd_gets(socket) == fd_fork_listen) {
604		index = fd_open();
605		if (index < 0)
606			return index;
607
608		ret = real.accept(fd, addr, addrlen);
609		if (ret < 0) {
610			fd_close(index, &fd);
611			return ret;
612		}
613
614		fd_store(index, ret, fd_normal, fd_fork_passive);
615		return index;
616	} else {
617		return real.accept(fd, addr, addrlen);
618	}
619}
620
621/*
622 * We can't fork RDMA connections and pass them from the parent to the child
623 * process.  Instead, we need to establish the RDMA connection after calling
624 * fork.  To do this, we delay establishing the RDMA connection until we try
625 * to send/receive on the server side.
626 */
627static void fork_active(int socket)
628{
629	struct sockaddr_storage addr;
630	int sfd, dfd, ret;
631	socklen_t len;
632	uint32_t msg;
633	long flags;
634
635	sfd = fd_getd(socket);
636
637	flags = real.fcntl(sfd, F_GETFL);
638	real.fcntl(sfd, F_SETFL, 0);
639	ret = real.recv(sfd, &msg, sizeof msg, MSG_PEEK);
640	real.fcntl(sfd, F_SETFL, flags);
641	if ((ret != sizeof msg) || msg)
642		goto err1;
643
644	len = sizeof addr;
645	ret = real.getpeername(sfd, (struct sockaddr *) &addr, &len);
646	if (ret)
647		goto err1;
648
649	dfd = rsocket(addr.ss_family, SOCK_STREAM, 0);
650	if (dfd < 0)
651		goto err1;
652
653	ret = rconnect(dfd, (struct sockaddr *) &addr, len);
654	if (ret)
655		goto err2;
656
657	set_rsocket_options(dfd);
658	copysockopts(dfd, sfd, &rs, &real);
659	real.shutdown(sfd, SHUT_RDWR);
660	real.close(sfd);
661	fd_store(socket, dfd, fd_rsocket, fd_ready);
662	return;
663
664err2:
665	rclose(dfd);
666err1:
667	fd_store(socket, sfd, fd_normal, fd_ready);
668}
669
670/*
671 * The server will start listening for the new connection, then send a
672 * message to the active side when the listen is ready.  This does leave
673 * fork unsupported in the following case: the server is nonblocking and
674 * calls select/poll waiting to receive data from the client.
675 */
676static void fork_passive(int socket)
677{
678	struct sockaddr_in6 sin6;
679	sem_t *sem;
680	int lfd, sfd, dfd, ret, param;
681	socklen_t len;
682	uint32_t msg;
683
684	sfd = fd_getd(socket);
685
686	len = sizeof sin6;
687	ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
688	if (ret)
689		goto out;
690	sin6.sin6_flowinfo = 0;
691	sin6.sin6_scope_id = 0;
692	memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
693
694	sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
695		       S_IRWXU | S_IRWXG, 1);
696	if (sem == SEM_FAILED) {
697		ret = -1;
698		goto out;
699	}
700
701	lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
702	if (lfd < 0) {
703		ret = lfd;
704		goto sclose;
705	}
706
707	param = 1;
708	rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param);
709
710	sem_wait(sem);
711	ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
712	if (ret)
713		goto lclose;
714
715	ret = rlisten(lfd, 1);
716	if (ret)
717		goto lclose;
718
719	msg = 0;
720	len = real.write(sfd, &msg, sizeof msg);
721	if (len != sizeof msg)
722		goto lclose;
723
724	dfd = raccept(lfd, NULL, NULL);
725	if (dfd < 0) {
726		ret  = dfd;
727		goto lclose;
728	}
729
730	set_rsocket_options(dfd);
731	copysockopts(dfd, sfd, &rs, &real);
732	real.shutdown(sfd, SHUT_RDWR);
733	real.close(sfd);
734	fd_store(socket, dfd, fd_rsocket, fd_ready);
735
736lclose:
737	rclose(lfd);
738	sem_post(sem);
739sclose:
740	sem_close(sem);
741out:
742	if (ret)
743		fd_store(socket, sfd, fd_normal, fd_ready);
744}
745
746static inline enum fd_type fd_fork_get(int index, int *fd)
747{
748	struct fd_info *fdi;
749
750	fdi = idm_lookup(&idm, index);
751	if (fdi) {
752		if (fdi->state == fd_fork_passive)
753			fork_passive(index);
754		else if (fdi->state == fd_fork_active)
755			fork_active(index);
756		*fd = fdi->fd;
757		return fdi->type;
758
759	} else {
760		*fd = index;
761		return fd_normal;
762	}
763}
764
765int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
766{
767	int fd, ret;
768
769	if (fd_get(socket, &fd) == fd_rsocket) {
770		ret = rconnect(fd, addr, addrlen);
771		if (!ret || errno == EINPROGRESS)
772			return ret;
773
774		ret = transpose_socket(socket, fd_normal);
775		if (ret < 0)
776			return ret;
777
778		rclose(fd);
779		fd = ret;
780	} else if (fd_gets(socket) == fd_fork) {
781		fd_store(socket, fd, fd_normal, fd_fork_active);
782	}
783
784	return real.connect(fd, addr, addrlen);
785}
786
787ssize_t recv(int socket, void *buf, size_t len, int flags)
788{
789	int fd;
790	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
791		rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
792}
793
794ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
795		 struct sockaddr *src_addr, socklen_t *addrlen)
796{
797	int fd;
798	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
799		rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
800		real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
801}
802
803ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
804{
805	int fd;
806	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
807		rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
808}
809
810ssize_t read(int socket, void *buf, size_t count)
811{
812	int fd;
813	init_preload();
814	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
815		rread(fd, buf, count) : real.read(fd, buf, count);
816}
817
818ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
819{
820	int fd;
821	init_preload();
822	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
823		rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
824}
825
826ssize_t send(int socket, const void *buf, size_t len, int flags)
827{
828	int fd;
829	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
830		rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
831}
832
833ssize_t sendto(int socket, const void *buf, size_t len, int flags,
834		const struct sockaddr *dest_addr, socklen_t addrlen)
835{
836	int fd;
837	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
838		rsendto(fd, buf, len, flags, dest_addr, addrlen) :
839		real.sendto(fd, buf, len, flags, dest_addr, addrlen);
840}
841
842ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
843{
844	int fd;
845	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
846		rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
847}
848
849ssize_t write(int socket, const void *buf, size_t count)
850{
851	int fd;
852	init_preload();
853	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
854		rwrite(fd, buf, count) : real.write(fd, buf, count);
855}
856
857ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
858{
859	int fd;
860	init_preload();
861	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
862		rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
863}
864
/*
 * Return a per-thread scratch array with room for at least 'nfds' pollfd
 * entries, growing the cached buffer when needed.  The buffer is owned by
 * this function and must not be freed by the caller.
 *
 * A request for zero entries is served with a one-entry buffer so that NULL
 * always means allocation failure; callers such as select() treat NULL as
 * ENOMEM, and select(0, NULL, NULL, NULL, &tv) is a legitimate call.
 *
 * Returns NULL only when the allocation fails.
 */
static struct pollfd *fds_alloc(nfds_t nfds)
{
	static __thread struct pollfd *rfds;
	static __thread nfds_t rnfds;

	if (!nfds)
		nfds = 1;

	if (nfds > rnfds) {
		free(rfds);

		rfds = malloc(sizeof(*rfds) * nfds);
		rnfds = rfds ? nfds : 0;
	}

	return rfds;
}
880
881int poll(struct pollfd *fds, nfds_t nfds, int timeout)
882{
883	struct pollfd *rfds;
884	int i, ret;
885
886	init_preload();
887	for (i = 0; i < nfds; i++) {
888		if (fd_gett(fds[i].fd) == fd_rsocket)
889			goto use_rpoll;
890	}
891
892	return real.poll(fds, nfds, timeout);
893
894use_rpoll:
895	rfds = fds_alloc(nfds);
896	if (!rfds)
897		return ERR(ENOMEM);
898
899	for (i = 0; i < nfds; i++) {
900		rfds[i].fd = fd_getd(fds[i].fd);
901		rfds[i].events = fds[i].events;
902		rfds[i].revents = 0;
903	}
904
905	ret = rpoll(rfds, nfds, timeout);
906
907	for (i = 0; i < nfds; i++)
908		fds[i].revents = rfds[i].revents;
909
910	return ret;
911}
912
/*
 * Build a pollfd array from select()-style sets.  Every descriptor below
 * *nfds that is present in any of the three sets yields one entry holding
 * its underlying descriptor and POLLIN/POLLOUT as requested.  On return
 * *nfds holds the number of entries written.
 */
static void select_to_rpoll(struct pollfd *fds, int *nfds,
			    fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	int fd, wanted, used = 0;

	for (fd = 0; fd < *nfds; fd++) {
		wanted = 0;
		if (readfds && FD_ISSET(fd, readfds))
			wanted |= POLLIN;
		if (writefds && FD_ISSET(fd, writefds))
			wanted |= POLLOUT;

		/* descriptors only in exceptfds still get an entry, with no events */
		if (!wanted && !(exceptfds && FD_ISSET(fd, exceptfds)))
			continue;

		fds[used].fd = fd_getd(fd);
		fds[used].events = wanted;
		used++;
	}

	*nfds = used;
}
931
/*
 * Translate rpoll() results back into select()-style sets.  Application
 * descriptors are scanned upwards from 0 and paired, in order, with the
 * pollfd entries by comparing underlying descriptors.  Returns the number of
 * bits set across the three sets.
 */
static int rpoll_to_select(struct pollfd *fds, int nfds,
			   fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	int fd = 0, i = 0, cnt = 0;
	short revents;

	while (i < nfds) {
		if (fd_getd(fd) != fds[i].fd) {
			fd++;
			continue;
		}

		revents = fds[i].revents;
		if (readfds && (revents & POLLIN)) {
			FD_SET(fd, readfds);
			cnt++;
		}

		if (writefds && (revents & POLLOUT)) {
			FD_SET(fd, writefds);
			cnt++;
		}

		/* anything other than IN/OUT is reported as an exception */
		if (exceptfds && (revents & ~(POLLIN | POLLOUT))) {
			FD_SET(fd, exceptfds);
			cnt++;
		}

		i++;
		fd++;
	}

	return cnt;
}
961
/*
 * Convert a select()-style timeout to the millisecond value used by poll().
 *
 * Returns -1 (wait indefinitely) when timeout is NULL.  Values that do not
 * fit in an int are clamped to the largest int rather than being truncated
 * into an arbitrary, possibly negative, number.
 */
static int rs_convert_timeout(struct timeval *timeout)
{
	/* largest int value, computed without relying on <limits.h> */
	const long long max_ms = (long long) (~0U >> 1);
	long long ms;

	if (!timeout)
		return -1;

	if (timeout->tv_sec > max_ms / 1000)
		return (int) max_ms;

	ms = (long long) timeout->tv_sec * 1000 + timeout->tv_usec / 1000;
	return (ms > max_ms) ? (int) max_ms : (int) ms;
}
966
967int select(int nfds, fd_set *readfds, fd_set *writefds,
968	   fd_set *exceptfds, struct timeval *timeout)
969{
970	struct pollfd *fds;
971	int ret;
972
973	fds = fds_alloc(nfds);
974	if (!fds)
975		return ERR(ENOMEM);
976
977	select_to_rpoll(fds, &nfds, readfds, writefds, exceptfds);
978	ret = rpoll(fds, nfds, rs_convert_timeout(timeout));
979
980	if (readfds)
981		FD_ZERO(readfds);
982	if (writefds)
983		FD_ZERO(writefds);
984	if (exceptfds)
985		FD_ZERO(exceptfds);
986
987	if (ret > 0)
988		ret = rpoll_to_select(fds, nfds, readfds, writefds, exceptfds);
989
990	return ret;
991}
992
993int shutdown(int socket, int how)
994{
995	int fd;
996	return (fd_get(socket, &fd) == fd_rsocket) ?
997		rshutdown(fd, how) : real.shutdown(fd, how);
998}
999
1000int close(int socket)
1001{
1002	struct fd_info *fdi;
1003	int ret;
1004
1005	init_preload();
1006	fdi = idm_lookup(&idm, socket);
1007	if (!fdi)
1008		return real.close(socket);
1009
1010	if (fdi->dupfd != -1) {
1011		ret = close(fdi->dupfd);
1012		if (ret)
1013			return ret;
1014	}
1015
1016	if (atomic_fetch_sub(&fdi->refcnt, 1) != 1)
1017		return 0;
1018
1019	idm_clear(&idm, socket);
1020	real.close(socket);
1021	ret = (fdi->type == fd_rsocket) ? rclose(fdi->fd) : real.close(fdi->fd);
1022	free(fdi);
1023	return ret;
1024}
1025
1026int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
1027{
1028	int fd;
1029	return (fd_get(socket, &fd) == fd_rsocket) ?
1030		rgetpeername(fd, addr, addrlen) :
1031		real.getpeername(fd, addr, addrlen);
1032}
1033
1034int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
1035{
1036	int fd;
1037	init_preload();
1038	return (fd_get(socket, &fd) == fd_rsocket) ?
1039		rgetsockname(fd, addr, addrlen) :
1040		real.getsockname(fd, addr, addrlen);
1041}
1042
1043int setsockopt(int socket, int level, int optname,
1044		const void *optval, socklen_t optlen)
1045{
1046	int fd;
1047	return (fd_get(socket, &fd) == fd_rsocket) ?
1048		rsetsockopt(fd, level, optname, optval, optlen) :
1049		real.setsockopt(fd, level, optname, optval, optlen);
1050}
1051
1052int getsockopt(int socket, int level, int optname,
1053		void *optval, socklen_t *optlen)
1054{
1055	int fd;
1056	return (fd_get(socket, &fd) == fd_rsocket) ?
1057		rgetsockopt(fd, level, optname, optval, optlen) :
1058		real.getsockopt(fd, level, optname, optval, optlen);
1059}
1060
1061int fcntl(int socket, int cmd, ... /* arg */)
1062{
1063	va_list args;
1064	long lparam;
1065	void *pparam;
1066	int fd, ret;
1067
1068	init_preload();
1069	va_start(args, cmd);
1070	switch (cmd) {
1071	case F_GETFD:
1072	case F_GETFL:
1073	case F_GETOWN:
1074	case F_GETSIG:
1075	case F_GETLEASE:
1076		ret = (fd_get(socket, &fd) == fd_rsocket) ?
1077			rfcntl(fd, cmd) : real.fcntl(fd, cmd);
1078		break;
1079	case F_DUPFD:
1080	/*case F_DUPFD_CLOEXEC:*/
1081	case F_SETFD:
1082	case F_SETFL:
1083	case F_SETOWN:
1084	case F_SETSIG:
1085	case F_SETLEASE:
1086	case F_NOTIFY:
1087		lparam = va_arg(args, long);
1088		ret = (fd_get(socket, &fd) == fd_rsocket) ?
1089			rfcntl(fd, cmd, lparam) : real.fcntl(fd, cmd, lparam);
1090		break;
1091	default:
1092		pparam = va_arg(args, void *);
1093		ret = (fd_get(socket, &fd) == fd_rsocket) ?
1094			rfcntl(fd, cmd, pparam) : real.fcntl(fd, cmd, pparam);
1095		break;
1096	}
1097	va_end(args);
1098	return ret;
1099}
1100
1101/*
1102 * dup2 is not thread safe
1103 */
1104int dup2(int oldfd, int newfd)
1105{
1106	struct fd_info *oldfdi, *newfdi;
1107	int ret;
1108
1109	init_preload();
1110	oldfdi = idm_lookup(&idm, oldfd);
1111	if (oldfdi) {
1112		if (oldfdi->state == fd_fork_passive)
1113			fork_passive(oldfd);
1114		else if (oldfdi->state == fd_fork_active)
1115			fork_active(oldfd);
1116	}
1117
1118	newfdi = idm_lookup(&idm, newfd);
1119	if (newfdi) {
1120		 /* newfd cannot have been dup'ed directly */
1121		if (atomic_load(&newfdi->refcnt) > 1)
1122			return ERR(EBUSY);
1123		close(newfd);
1124	}
1125
1126	ret = real.dup2(oldfd, newfd);
1127	if (!oldfdi || ret != newfd)
1128		return ret;
1129
1130	newfdi = calloc(1, sizeof(*newfdi));
1131	if (!newfdi) {
1132		close(newfd);
1133		return ERR(ENOMEM);
1134	}
1135
1136	pthread_mutex_lock(&mut);
1137	idm_set(&idm, newfd, newfdi);
1138	pthread_mutex_unlock(&mut);
1139
1140	newfdi->fd = oldfdi->fd;
1141	newfdi->type = oldfdi->type;
1142	if (oldfdi->dupfd != -1) {
1143		newfdi->dupfd = oldfdi->dupfd;
1144		oldfdi = idm_lookup(&idm, oldfdi->dupfd);
1145	} else {
1146		newfdi->dupfd = oldfd;
1147	}
1148	atomic_store(&newfdi->refcnt, 1);
1149	atomic_fetch_add(&oldfdi->refcnt, 1);
1150	return newfd;
1151}
1152
1153ssize_t sendfile(int out_fd, int in_fd, off_t *offset, size_t count)
1154{
1155	void *file_addr;
1156	int fd;
1157	size_t ret;
1158
1159	if (fd_get(out_fd, &fd) != fd_rsocket)
1160		return real.sendfile(fd, in_fd, offset, count);
1161
1162	file_addr = mmap(NULL, count, PROT_READ, 0, in_fd, offset ? *offset : 0);
1163	if (file_addr == (void *) -1)
1164		return -1;
1165
1166	ret = rwrite(fd, file_addr, count);
1167	if ((ret > 0) && offset)
1168		lseek(in_fd, ret, SEEK_CUR);
1169	munmap(file_addr, count);
1170	return ret;
1171}
1172
1173int __fxstat(int ver, int socket, struct stat *buf)
1174{
1175	int fd, ret;
1176
1177	init_preload();
1178	if (fd_get(socket, &fd) == fd_rsocket) {
1179		ret = real.fxstat(ver, socket, buf);
1180		if (!ret)
1181			buf->st_mode = (buf->st_mode & ~S_IFMT) | __S_IFSOCK;
1182	} else {
1183		ret = real.fxstat(ver, fd, buf);
1184	}
1185	return ret;
1186}
1187