1/*	$OpenBSD: mproc.c,v 1.40 2024/01/20 09:01:03 claudio Exp $	*/
2
3/*
4 * Copyright (c) 2012 Eric Faurot <eric@faurot.net>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <errno.h>
20#include <stdlib.h>
21#include <string.h>
22#include <unistd.h>
23
24#include "smtpd.h"
25#include "log.h"
26
27static void mproc_dispatch(int, short, void *);
28
29static ssize_t imsg_read_nofd(struct imsgbuf *);
30
31int
32mproc_fork(struct mproc *p, const char *path, char *argv[])
33{
34	int sp[2];
35
36	if (socketpair(AF_UNIX, SOCK_STREAM, PF_UNSPEC, sp) == -1)
37		return (-1);
38
39	io_set_nonblocking(sp[0]);
40	io_set_nonblocking(sp[1]);
41
42	if ((p->pid = fork()) == -1)
43		goto err;
44
45	if (p->pid == 0) {
46		/* child process */
47		dup2(sp[0], STDIN_FILENO);
48		if (closefrom(STDERR_FILENO + 1) == -1)
49			exit(1);
50
51		execv(path, argv);
52		fatal("execv: %s", path);
53	}
54
55	/* parent process */
56	close(sp[0]);
57	mproc_init(p, sp[1]);
58	return (0);
59
60err:
61	log_warn("warn: Failed to start process %s, instance of %s", argv[0], path);
62	close(sp[0]);
63	close(sp[1]);
64	return (-1);
65}
66
67void
68mproc_init(struct mproc *p, int fd)
69{
70	imsg_init(&p->imsgbuf, fd);
71}
72
73void
74mproc_clear(struct mproc *p)
75{
76	log_debug("debug: clearing p=%s, fd=%d, pid=%d", p->name, p->imsgbuf.fd, p->pid);
77
78	if (p->events)
79		event_del(&p->ev);
80	close(p->imsgbuf.fd);
81	imsg_clear(&p->imsgbuf);
82}
83
84void
85mproc_enable(struct mproc *p)
86{
87	if (p->enable == 0) {
88		log_trace(TRACE_MPROC, "mproc: %s -> %s: enabled",
89		    proc_name(smtpd_process),
90		    proc_name(p->proc));
91		p->enable = 1;
92	}
93	mproc_event_add(p);
94}
95
96void
97mproc_disable(struct mproc *p)
98{
99	if (p->enable == 1) {
100		log_trace(TRACE_MPROC, "mproc: %s -> %s: disabled",
101		    proc_name(smtpd_process),
102		    proc_name(p->proc));
103		p->enable = 0;
104	}
105	mproc_event_add(p);
106}
107
108void
109mproc_event_add(struct mproc *p)
110{
111	short	events;
112
113	if (p->enable)
114		events = EV_READ;
115	else
116		events = 0;
117
118	if (p->imsgbuf.w.queued)
119		events |= EV_WRITE;
120
121	if (p->events)
122		event_del(&p->ev);
123
124	p->events = events;
125	if (events) {
126		event_set(&p->ev, p->imsgbuf.fd, events, mproc_dispatch, p);
127		event_add(&p->ev, NULL);
128	}
129}
130
/*
 * libevent callback for the imsg channel of "p".  The fd argument is not
 * used; the channel descriptor is taken from p->imsgbuf.
 *
 * On EV_READ, pending bytes are pulled into the read buffer; on EV_WRITE,
 * queued output is flushed.  Every complete imsg is then passed to
 * p->handler and freed.  When the channel is found closed or broken, the
 * handler is called with a NULL imsg and this function returns at once
 * without touching "p" again.  Otherwise the event is re-armed through
 * mproc_event_add().
 */
static void
mproc_dispatch(int fd, short event, void *arg)
{
	struct mproc	*p = arg;
	struct imsg	 imsg;
	ssize_t		 n;

	/*
	 * mproc_event_add() registers the event without EV_PERSIST, so it is
	 * no longer pending once this callback runs; record that.
	 */
	p->events = 0;

	if (event & EV_READ) {

		/* PROC_CLIENT peers are read with plain recv(2): no fd passing. */
		if (p->proc == PROC_CLIENT)
			n = imsg_read_nofd(&p->imsgbuf);
		else
			n = imsg_read(&p->imsgbuf);

		switch (n) {
		case -1:
			/* EAGAIN: nothing readable now; still drain buffered imsgs. */
			if (errno == EAGAIN)
				break;
			log_warn("warn: %s -> %s: imsg_read",
			    proc_name(smtpd_process),  p->name);
			fatal("exiting");
			/* NOTREACHED */
		case 0:
			/* this pipe is dead, so remove the event handler */
			log_debug("debug: %s -> %s: pipe closed",
			    proc_name(smtpd_process),  p->name);
			p->handler(p, NULL);
			return;
		default:
			break;
		}
	}

	if (event & EV_WRITE) {
		/* 0, or a failure other than EAGAIN, is treated as a dead peer. */
		n = msgbuf_write(&p->imsgbuf.w);
		if (n == 0 || (n == -1 && errno != EAGAIN)) {
			/* this pipe is dead, so remove the event handler */
			log_debug("debug: %s -> %s: pipe closed",
			    proc_name(smtpd_process),  p->name);
			p->handler(p, NULL);
			return;
		}
	}

	/* Hand every complete imsg in the read buffer to the handler. */
	for (;;) {
		if ((n = imsg_get(&p->imsgbuf, &imsg)) == -1) {

			/*
			 * Invalid imsg: a client on the control socket is
			 * dropped via handler(p, NULL); any other peer is
			 * a fatal error.
			 */
			if (smtpd_process == PROC_CONTROL &&
			    p->proc == PROC_CLIENT) {
				log_warnx("warn: client sent invalid imsg "
				    "over control socket");
				p->handler(p, NULL);
				return;
			}
			log_warn("fatal: %s: error in imsg_get for %s",
			    proc_name(smtpd_process),  p->name);
			fatalx(NULL);
		}
		/* No complete imsg left in the buffer. */
		if (n == 0)
			break;

		/*
		 * NOTE(review): p is used again after this call (imsg_free
		 * is independent, but the loop and mproc_event_add() below
		 * read p), so handlers presumably must not free p when given
		 * a non-NULL imsg; confirm.
		 */
		p->handler(p, &imsg);

		imsg_free(&imsg);
	}

	/* Re-register: EV_READ if enabled, EV_WRITE if output is still queued. */
	mproc_event_add(p);
}
201
202/* This should go into libutil */
203static ssize_t
204imsg_read_nofd(struct imsgbuf *ibuf)
205{
206	ssize_t	 n;
207	char	*buf;
208	size_t	 len;
209
210	buf = ibuf->r.buf + ibuf->r.wpos;
211	len = sizeof(ibuf->r.buf) - ibuf->r.wpos;
212
213	while ((n = recv(ibuf->fd, buf, len, 0)) == -1) {
214		if (errno != EINTR)
215			return (n);
216	}
217
218	ibuf->r.wpos += n;
219	return (n);
220}
221
222void
223m_forward(struct mproc *p, struct imsg *imsg)
224{
225	imsg_compose(&p->imsgbuf, imsg->hdr.type, imsg->hdr.peerid,
226	    imsg->hdr.pid, imsg_get_fd(imsg), imsg->data,
227	    imsg->hdr.len - sizeof(imsg->hdr));
228
229	if (imsg->hdr.type != IMSG_STAT_DECREMENT &&
230	    imsg->hdr.type != IMSG_STAT_INCREMENT)
231		log_trace(TRACE_MPROC, "mproc: %s -> %s : %zu %s (forward)",
232		    proc_name(smtpd_process),
233		    proc_name(p->proc),
234		    imsg->hdr.len - sizeof(imsg->hdr),
235		    imsg_to_str(imsg->hdr.type));
236
237	mproc_event_add(p);
238}
239
240void
241m_compose(struct mproc *p, uint32_t type, uint32_t peerid, pid_t pid, int fd,
242    void *data, size_t len)
243{
244	imsg_compose(&p->imsgbuf, type, peerid, pid, fd, data, len);
245
246	if (type != IMSG_STAT_DECREMENT &&
247	    type != IMSG_STAT_INCREMENT)
248		log_trace(TRACE_MPROC, "mproc: %s -> %s : %zu %s",
249		    proc_name(smtpd_process),
250		    proc_name(p->proc),
251		    len,
252		    imsg_to_str(type));
253
254	mproc_event_add(p);
255}
256
257void
258m_composev(struct mproc *p, uint32_t type, uint32_t peerid, pid_t pid,
259    int fd, const struct iovec *iov, int n)
260{
261	size_t	len;
262	int	i;
263
264	imsg_composev(&p->imsgbuf, type, peerid, pid, fd, iov, n);
265
266	len = 0;
267	for (i = 0; i < n; i++)
268		len += iov[i].iov_len;
269
270	if (type != IMSG_STAT_DECREMENT &&
271	    type != IMSG_STAT_INCREMENT)
272		log_trace(TRACE_MPROC, "mproc: %s -> %s : %zu %s",
273		    proc_name(smtpd_process),
274		    proc_name(p->proc),
275		    len,
276		    imsg_to_str(type));
277
278	mproc_event_add(p);
279}
280
281void
282m_create(struct mproc *p, uint32_t type, uint32_t peerid, pid_t pid, int fd)
283{
284	p->m_pos = 0;
285	p->m_type = type;
286	p->m_peerid = peerid;
287	p->m_pid = pid;
288	p->m_fd = fd;
289}
290
291void
292m_add(struct mproc *p, const void *data, size_t len)
293{
294	size_t	 alloc;
295	void	*tmp;
296
297	if (p->m_pos + len + IMSG_HEADER_SIZE > MAX_IMSGSIZE) {
298		log_warnx("warn: message too large");
299		fatal(NULL);
300	}
301
302	alloc = p->m_alloc ? p->m_alloc : 128;
303	while (p->m_pos + len > alloc)
304		alloc *= 2;
305	if (alloc != p->m_alloc) {
306		log_trace(TRACE_MPROC, "mproc: %s -> %s: realloc %zu -> %zu",
307		    proc_name(smtpd_process),
308		    proc_name(p->proc),
309		    p->m_alloc,
310		    alloc);
311
312		tmp = recallocarray(p->m_buf, p->m_alloc, alloc, 1);
313		if (tmp == NULL)
314			fatal("realloc");
315		p->m_alloc = alloc;
316		p->m_buf = tmp;
317	}
318
319	memmove(p->m_buf + p->m_pos, data, len);
320	p->m_pos += len;
321}
322
323void
324m_close(struct mproc *p)
325{
326	if (imsg_compose(&p->imsgbuf, p->m_type, p->m_peerid, p->m_pid, p->m_fd,
327	    p->m_buf, p->m_pos) == -1)
328		fatal("imsg_compose");
329
330	log_trace(TRACE_MPROC, "mproc: %s -> %s : %zu %s",
331		    proc_name(smtpd_process),
332		    proc_name(p->proc),
333		    p->m_pos,
334		    imsg_to_str(p->m_type));
335
336	mproc_event_add(p);
337}
338
339void
340m_flush(struct mproc *p)
341{
342	if (imsg_compose(&p->imsgbuf, p->m_type, p->m_peerid, p->m_pid, p->m_fd,
343	    p->m_buf, p->m_pos) == -1)
344		fatal("imsg_compose");
345
346	log_trace(TRACE_MPROC, "mproc: %s -> %s : %zu %s (flush)",
347	    proc_name(smtpd_process),
348	    proc_name(p->proc),
349	    p->m_pos,
350	    imsg_to_str(p->m_type));
351
352	p->m_pos = 0;
353
354	if (imsg_flush(&p->imsgbuf) == -1)
355		fatal("imsg_flush");
356}
357
358static struct imsg * current;
359
360static void
361m_error(const char *error)
362{
363	char	buf[512];
364
365	(void)snprintf(buf, sizeof buf, "%s: %s: %s",
366	    proc_name(smtpd_process),
367	    imsg_to_str(current->hdr.type),
368	    error);
369	fatalx("%s", buf);
370}
371
372void
373m_msg(struct msg *m, struct imsg *imsg)
374{
375	current = imsg;
376	m->pos = imsg->data;
377	m->end = m->pos + (imsg->hdr.len - sizeof(imsg->hdr));
378}
379
380void
381m_end(struct msg *m)
382{
383	if (m->pos != m->end)
384		m_error("not at msg end");
385}
386
387int
388m_is_eom(struct msg *m)
389{
390	return (m->pos == m->end);
391}
392
393static inline void
394m_get(struct msg *m, void *dst, size_t sz)
395{
396	if (sz > MAX_IMSGSIZE ||
397	    m->end - m->pos < (ssize_t)sz)
398		fatalx("msg too short");
399
400	memmove(dst, m->pos, sz);
401	m->pos += sz;
402}
403
/*
 * Append an int, in its native in-memory representation, to the message
 * being built.  (The stray ';' after the body was an empty file-scope
 * declaration, which ISO C does not allow.)
 */
void
m_add_int(struct mproc *m, int v)
{
	m_add(m, &v, sizeof(v));
}
409
/*
 * Append a uint32_t, in host byte order, to the message being built.
 */
void
m_add_u32(struct mproc *m, uint32_t u32)
{
	m_add(m, &u32, sizeof(u32));
}
415
/*
 * Append a size_t, in its native representation, to the message being
 * built.  Also used as the length prefix by m_add_data() and
 * m_add_sockaddr().
 */
void
m_add_size(struct mproc *m, size_t sz)
{
	m_add(m, &sz, sizeof(sz));
}
421
/*
 * Append a time_t, in its native representation, to the message being
 * built.
 */
void
m_add_time(struct mproc *m, time_t v)
{
	m_add(m, &v, sizeof(v));
}
427
428void
429m_add_timeval(struct mproc *m, struct timeval *tv)
430{
431	m_add(m, tv, sizeof(*tv));
432}
433
434
/*
 * Append a possibly-NULL C string.  Encoding: a single '\0' byte for NULL,
 * otherwise the marker byte 's' followed by the string and its terminating
 * NUL.  Decoded by m_get_string().
 */
void
m_add_string(struct mproc *m, const char *v)
{
	if (v == NULL) {
		m_add(m, "\0", 1);
		return;
	}

	m_add(m, "s", 1);
	m_add(m, v, strlen(v) + 1);
}
445
/*
 * Append a length-prefixed blob: "len" as a size_t, then "len" bytes from
 * "v".  Decoded by m_get_data().
 */
void
m_add_data(struct mproc *m, const void *v, size_t len)
{
	m_add_size(m, len);
	m_add(m, v, len);
}
452
/* Append a 64-bit identifier, in host byte order, to the message. */
void
m_add_id(struct mproc *m, uint64_t v)
{
	uint64_t	id = v;

	m_add(m, &id, sizeof(id));
}
458
/* Append a 64-bit envelope id, in host byte order, to the message. */
void
m_add_evpid(struct mproc *m, uint64_t v)
{
	uint64_t	evpid = v;

	m_add(m, &evpid, sizeof(evpid));
}
464
/* Append a 32-bit message id, in host byte order, to the message. */
void
m_add_msgid(struct mproc *m, uint32_t v)
{
	uint32_t	msgid = v;

	m_add(m, &msgid, sizeof(msgid));
}
470
471void
472m_add_sockaddr(struct mproc *m, const struct sockaddr *sa)
473{
474	m_add_size(m, sa->sa_len);
475	m_add(m, sa, sa->sa_len);
476}
477
478void
479m_add_mailaddr(struct mproc *m, const struct mailaddr *maddr)
480{
481	m_add(m, maddr, sizeof(*maddr));
482}
483
484void
485m_add_envelope(struct mproc *m, const struct envelope *evp)
486{
487	char	buf[sizeof(*evp)];
488
489	envelope_dump_buffer(evp, buf, sizeof(buf));
490	m_add_evpid(m, evp->id);
491	m_add_string(m, buf);
492}
493
494void
495m_add_params(struct mproc *m, struct dict *d)
496{
497	const char *key;
498	char *value;
499	void *iter;
500
501	if (d == NULL) {
502		m_add_size(m, 0);
503		return;
504	}
505	m_add_size(m, dict_count(d));
506	iter = NULL;
507	while (dict_iter(d, &iter, &key, (void **)&value)) {
508		m_add_string(m, key);
509		m_add_string(m, value);
510	}
511}
512
513void
514m_get_int(struct msg *m, int *i)
515{
516	m_get(m, i, sizeof(*i));
517}
518
519void
520m_get_u32(struct msg *m, uint32_t *u32)
521{
522	m_get(m, u32, sizeof(*u32));
523}
524
/* Read a native size_t from the message cursor into *sz. */
void
m_get_size(struct msg *m, size_t *sz)
{
	const size_t	width = sizeof(size_t);

	m_get(m, sz, width);
}
530
/* Read a native time_t from the message cursor into *t. */
void
m_get_time(struct msg *m, time_t *t)
{
	const size_t	width = sizeof(time_t);

	m_get(m, t, width);
}
536
537void
538m_get_timeval(struct msg *m, struct timeval *tv)
539{
540	m_get(m, tv, sizeof(*tv));
541}
542
543void
544m_get_string(struct msg *m, const char **s)
545{
546	uint8_t	*end;
547	char c;
548
549	if (m->pos >= m->end)
550		m_error("msg too short");
551
552	c = *m->pos++;
553	if (c == '\0') {
554		*s = NULL;
555		return;
556	}
557
558	if (m->pos >= m->end)
559		m_error("msg too short");
560	end = memchr(m->pos, 0, m->end - m->pos);
561	if (end == NULL)
562		m_error("unterminated string");
563
564	*s = m->pos;
565	m->pos = end + 1;
566}
567
568void
569m_get_data(struct msg *m, const void **data, size_t *sz)
570{
571	m_get_size(m, sz);
572
573	if (*sz == 0) {
574		*data = NULL;
575		return;
576	}
577
578	if (m->pos + *sz > m->end)
579		m_error("msg too short");
580
581	*data = m->pos;
582	m->pos += *sz;
583}
584
585void
586m_get_evpid(struct msg *m, uint64_t *evpid)
587{
588	m_get(m, evpid, sizeof(*evpid));
589}
590
591void
592m_get_msgid(struct msg *m, uint32_t *msgid)
593{
594	m_get(m, msgid, sizeof(*msgid));
595}
596
597void
598m_get_id(struct msg *m, uint64_t *id)
599{
600	m_get(m, id, sizeof(*id));
601}
602
/*
 * Read a socket address written by m_add_sockaddr(): a size_t length, then
 * that many bytes copied into *sa.
 *
 * NOTE(review): the destination size is not passed in; this assumes
 * callers provide storage at least as large as struct sockaddr_storage
 * (the largest legitimate address) — confirm at call sites.
 */
void
m_get_sockaddr(struct msg *m, struct sockaddr *sa)
{
	size_t len;

	m_get_size(m, &len);

	/*
	 * The length comes from the peer; refuse anything larger than any
	 * real address so a bogus value cannot overrun the caller's buffer.
	 */
	if (len > sizeof(struct sockaddr_storage))
		fatalx("m_get_sockaddr: invalid sockaddr length");

	m_get(m, sa, len);
}
611
612void
613m_get_mailaddr(struct msg *m, struct mailaddr *maddr)
614{
615	m_get(m, maddr, sizeof(*maddr));
616}
617
618void
619m_get_envelope(struct msg *m, struct envelope *evp)
620{
621	uint64_t	 evpid;
622	const char	*buf;
623
624	m_get_evpid(m, &evpid);
625	m_get_string(m, &buf);
626	if (buf == NULL)
627		fatalx("empty envelope buffer");
628
629	if (!envelope_load_buffer(evp, buf, strlen(buf)))
630		fatalx("failed to retrieve envelope");
631	evp->id = evpid;
632}
633
/*
 * Decode a dictionary written by m_add_params() into "d", which is
 * initialized first.  Each value is strdup()ed and owned by "d"; release
 * them with m_clear_params().  NULL keys or values are fatal.
 */
void
m_get_params(struct msg *m, struct dict *d)
{
	size_t		 c;
	const char	*key;
	const char	*value;
	char		*tmp;

	dict_init(d);

	m_get_size(m, &c);

	for (; c; c--) {
		m_get_string(m, &key);
		m_get_string(m, &value);

		/*
		 * m_add_string() can encode NULL; strdup(NULL) is undefined
		 * and a NULL key cannot be stored, so fail cleanly instead.
		 */
		if (key == NULL || value == NULL)
			fatalx("m_get_params: NULL key or value");

		if ((tmp = strdup(value)) == NULL)
			fatal("m_get_params");

		/*
		 * dict_set() returns the value it replaces for a repeated
		 * key (NULL for a new one); free it rather than leak it.
		 */
		free(dict_set(d, key, tmp));
	}
}
654
/*
 * Empty dictionary "d", releasing each value with free(3) (matching the
 * strdup() copies made by m_get_params()).
 */
void
m_clear_params(struct dict *d)
{
	void	*v;

	for (;;) {
		if (!dict_poproot(d, &v))
			break;
		free(v);
	}
}
663