1// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2#include <errno.h>
3#include <poll.h>
4#include <string.h>
5#include <stdlib.h>
6#include <stdio.h>
7#include <unistd.h>
8#include <linux/types.h>
9#include <linux/genetlink.h>
10#include <sys/socket.h>
11
12#include "ynl.h"
13
#define ARRAY_SIZE(arr)		(sizeof(arr) / sizeof((arr)[0]))

/* Format a message into yse->msg; a NULL @yse is silently ignored. */
#define __yerr_msg(yse, _msg...)					\
	({								\
		struct ynl_error *_yse = (yse);				\
									\
		if (_yse) {						\
			snprintf(_yse->msg, sizeof(_yse->msg) - 1,  _msg); \
			_yse->msg[sizeof(_yse->msg) - 1] = 0;		\
		}							\
	})

/* Store an error code into yse->code; a NULL @yse is silently ignored. */
#define __yerr_code(yse, _code...)		\
	({					\
		struct ynl_error *_yse = (yse);	\
						\
		if (_yse) {			\
			_yse->code = _code;	\
		}				\
	})

/*
 * Record both code and message. @_code is evaluated *before* the message
 * is formatted: callers commonly pass errno, and library calls such as
 * snprintf() are permitted to modify errno even when they succeed.
 * @yse is evaluated only once.
 */
#define __yerr(yse, _code, _msg...)				\
	({							\
		struct ynl_error *_yerr_dst = (yse);		\
		__typeof__((_code)) _yerr_val = (_code);	\
								\
		__yerr_msg(_yerr_dst, _msg);			\
		__yerr_code(_yerr_dst, _yerr_val);		\
	})

#define __perr(yse, _msg)		__yerr(yse, errno, _msg)

#define yerr_msg(_ys, _msg...)		__yerr_msg(&(_ys)->err, _msg)
#define yerr(_ys, _code, _msg...)	__yerr(&(_ys)->err, _code, _msg)
#define perr(_ys, _msg)			__yerr(&(_ys)->err, errno, _msg)
46
47/* -- Netlink boiler plate */
48static int
49ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type,
50			char *str, int str_sz, int *n)
51{
52	if (!policy) {
53		if (*n < str_sz)
54			*n += snprintf(str, str_sz, "!policy");
55		return 1;
56	}
57
58	if (type > policy->max_attr) {
59		if (*n < str_sz)
60			*n += snprintf(str, str_sz, "!oob");
61		return 1;
62	}
63
64	if (!policy->table[type].name) {
65		if (*n < str_sz)
66			*n += snprintf(str, str_sz, "!name");
67		return 1;
68	}
69
70	if (*n < str_sz)
71		*n += snprintf(str, str_sz - *n,
72			       ".%s", policy->table[type].name);
73	return 0;
74}
75
76static int
77ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off,
78	     struct ynl_policy_nest *policy, char *str, int str_sz,
79	     struct ynl_policy_nest **nest_pol)
80{
81	unsigned int astart_off, aend_off;
82	const struct nlattr *attr;
83	unsigned int data_len;
84	unsigned int type;
85	bool found = false;
86	int n = 0;
87
88	if (!policy) {
89		if (n < str_sz)
90			n += snprintf(str, str_sz, "!policy");
91		return n;
92	}
93
94	data_len = end - start;
95
96	ynl_attr_for_each_payload(start, data_len, attr) {
97		astart_off = (char *)attr - (char *)start;
98		aend_off = astart_off + ynl_attr_data_len(attr);
99		if (aend_off <= off)
100			continue;
101
102		found = true;
103		break;
104	}
105	if (!found)
106		return 0;
107
108	off -= astart_off;
109
110	type = ynl_attr_type(attr);
111
112	if (ynl_err_walk_report_one(policy, type, str, str_sz, &n))
113		return n;
114
115	if (!off) {
116		if (nest_pol)
117			*nest_pol = policy->table[type].nest;
118		return n;
119	}
120
121	if (!policy->table[type].nest) {
122		if (n < str_sz)
123			n += snprintf(str, str_sz, "!nest");
124		return n;
125	}
126
127	off -= sizeof(struct nlattr);
128	start =  ynl_attr_data(attr);
129	end = start + ynl_attr_data_len(attr);
130
131	return n + ynl_err_walk(ys, start, end, off, policy->table[type].nest,
132				&str[n], str_sz - n, nest_pol);
133}
134
/*
 * Local copies of the "missing attribute" ext-ack TLV ids, which follow
 * NLMSGERR_ATTR_POLICY, so the library also builds against UAPI headers
 * that predate them. NLMSGERR_ATTR_MAX is widened by two to cover them;
 * the inner reference is the header's own constant, since a macro is not
 * re-expanded inside its own definition.
 */
#define NLMSGERR_ATTR_MISS_TYPE (NLMSGERR_ATTR_POLICY + 1)
#define NLMSGERR_ATTR_MISS_NEST (NLMSGERR_ATTR_MISS_TYPE + 1)
#define NLMSGERR_ATTR_MAX (NLMSGERR_ATTR_MAX + 2)
138
139static int
140ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh,
141		  unsigned int hlen)
142{
143	const struct nlattr *tb[NLMSGERR_ATTR_MAX + 1] = {};
144	char miss_attr[sizeof(ys->err.msg)];
145	char bad_attr[sizeof(ys->err.msg)];
146	const struct nlattr *attr;
147	const char *str = NULL;
148
149	if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS)) {
150		yerr_msg(ys, "%s", strerror(ys->err.code));
151		return YNL_PARSE_CB_OK;
152	}
153
154	ynl_attr_for_each(attr, nlh, hlen) {
155		unsigned int len, type;
156
157		len = ynl_attr_data_len(attr);
158		type = ynl_attr_type(attr);
159
160		if (type > NLMSGERR_ATTR_MAX)
161			continue;
162
163		tb[type] = attr;
164
165		switch (type) {
166		case NLMSGERR_ATTR_OFFS:
167		case NLMSGERR_ATTR_MISS_TYPE:
168		case NLMSGERR_ATTR_MISS_NEST:
169			if (len != sizeof(__u32))
170				return YNL_PARSE_CB_ERROR;
171			break;
172		case NLMSGERR_ATTR_MSG:
173			str = ynl_attr_get_str(attr);
174			if (str[len - 1])
175				return YNL_PARSE_CB_ERROR;
176			break;
177		default:
178			break;
179		}
180	}
181
182	bad_attr[0] = '\0';
183	miss_attr[0] = '\0';
184
185	if (tb[NLMSGERR_ATTR_OFFS]) {
186		unsigned int n, off;
187		void *start, *end;
188
189		ys->err.attr_offs = ynl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]);
190
191		n = snprintf(bad_attr, sizeof(bad_attr), "%sbad attribute: ",
192			     str ? " (" : "");
193
194		start = ynl_nlmsg_data_offset(ys->nlh, ys->family->hdr_len);
195		end = ynl_nlmsg_end_addr(ys->nlh);
196
197		off = ys->err.attr_offs;
198		off -= sizeof(struct nlmsghdr);
199		off -= ys->family->hdr_len;
200
201		n += ynl_err_walk(ys, start, end, off, ys->req_policy,
202				  &bad_attr[n], sizeof(bad_attr) - n, NULL);
203
204		if (n >= sizeof(bad_attr))
205			n = sizeof(bad_attr) - 1;
206		bad_attr[n] = '\0';
207	}
208	if (tb[NLMSGERR_ATTR_MISS_TYPE]) {
209		struct ynl_policy_nest *nest_pol = NULL;
210		unsigned int n, off, type;
211		void *start, *end;
212		int n2;
213
214		type = ynl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]);
215
216		n = snprintf(miss_attr, sizeof(miss_attr), "%smissing attribute: ",
217			     bad_attr[0] ? ", " : (str ? " (" : ""));
218
219		start = ynl_nlmsg_data_offset(ys->nlh, ys->family->hdr_len);
220		end = ynl_nlmsg_end_addr(ys->nlh);
221
222		nest_pol = ys->req_policy;
223		if (tb[NLMSGERR_ATTR_MISS_NEST]) {
224			off = ynl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]);
225			off -= sizeof(struct nlmsghdr);
226			off -= ys->family->hdr_len;
227
228			n += ynl_err_walk(ys, start, end, off, ys->req_policy,
229					  &miss_attr[n], sizeof(miss_attr) - n,
230					  &nest_pol);
231		}
232
233		n2 = 0;
234		ynl_err_walk_report_one(nest_pol, type, &miss_attr[n],
235					sizeof(miss_attr) - n, &n2);
236		n += n2;
237
238		if (n >= sizeof(miss_attr))
239			n = sizeof(miss_attr) - 1;
240		miss_attr[n] = '\0';
241	}
242
243	/* Implicitly depend on ys->err.code already set */
244	if (str)
245		yerr_msg(ys, "Kernel %s: '%s'%s%s%s",
246			 ys->err.code ? "error" : "warning",
247			 str, bad_attr, miss_attr,
248			 bad_attr[0] || miss_attr[0] ? ")" : "");
249	else if (bad_attr[0] || miss_attr[0])
250		yerr_msg(ys, "Kernel %s: %s%s",
251			 ys->err.code ? "error" : "warning",
252			 bad_attr, miss_attr);
253	else
254		yerr_msg(ys, "%s", strerror(ys->err.code));
255
256	return YNL_PARSE_CB_OK;
257}
258
259static int
260ynl_cb_error(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg)
261{
262	const struct nlmsgerr *err = ynl_nlmsg_data(nlh);
263	unsigned int hlen;
264	int code;
265
266	code = err->error >= 0 ? err->error : -err->error;
267	yarg->ys->err.code = code;
268	errno = code;
269
270	hlen = sizeof(*err);
271	if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
272		hlen += ynl_nlmsg_data_len(&err->msg);
273
274	ynl_ext_ack_check(yarg->ys, nlh, hlen);
275
276	return code ? YNL_PARSE_CB_ERROR : YNL_PARSE_CB_STOP;
277}
278
279static int ynl_cb_done(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg)
280{
281	int err;
282
283	err = *(int *)NLMSG_DATA(nlh);
284	if (err < 0) {
285		yarg->ys->err.code = -err;
286		errno = -err;
287
288		ynl_ext_ack_check(yarg->ys, nlh, sizeof(int));
289
290		return YNL_PARSE_CB_ERROR;
291	}
292	return YNL_PARSE_CB_STOP;
293}
294
295/* Attribute validation */
296
297int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr)
298{
299	struct ynl_policy_attr *policy;
300	unsigned int type, len;
301	unsigned char *data;
302
303	data = ynl_attr_data(attr);
304	len = ynl_attr_data_len(attr);
305	type = ynl_attr_type(attr);
306	if (type > yarg->rsp_policy->max_attr) {
307		yerr(yarg->ys, YNL_ERROR_INTERNAL,
308		     "Internal error, validating unknown attribute");
309		return -1;
310	}
311
312	policy = &yarg->rsp_policy->table[type];
313
314	switch (policy->type) {
315	case YNL_PT_REJECT:
316		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
317		     "Rejected attribute (%s)", policy->name);
318		return -1;
319	case YNL_PT_IGNORE:
320		break;
321	case YNL_PT_U8:
322		if (len == sizeof(__u8))
323			break;
324		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
325		     "Invalid attribute (u8 %s)", policy->name);
326		return -1;
327	case YNL_PT_U16:
328		if (len == sizeof(__u16))
329			break;
330		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
331		     "Invalid attribute (u16 %s)", policy->name);
332		return -1;
333	case YNL_PT_U32:
334		if (len == sizeof(__u32))
335			break;
336		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
337		     "Invalid attribute (u32 %s)", policy->name);
338		return -1;
339	case YNL_PT_U64:
340		if (len == sizeof(__u64))
341			break;
342		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
343		     "Invalid attribute (u64 %s)", policy->name);
344		return -1;
345	case YNL_PT_UINT:
346		if (len == sizeof(__u32) || len == sizeof(__u64))
347			break;
348		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
349		     "Invalid attribute (uint %s)", policy->name);
350		return -1;
351	case YNL_PT_FLAG:
352		/* Let flags grow into real attrs, why not.. */
353		break;
354	case YNL_PT_NEST:
355		if (!len || len >= sizeof(*attr))
356			break;
357		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
358		     "Invalid attribute (nest %s)", policy->name);
359		return -1;
360	case YNL_PT_BINARY:
361		if (!policy->len || len == policy->len)
362			break;
363		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
364		     "Invalid attribute (binary %s)", policy->name);
365		return -1;
366	case YNL_PT_NUL_STR:
367		if ((!policy->len || len <= policy->len) && !data[len - 1])
368			break;
369		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
370		     "Invalid attribute (string %s)", policy->name);
371		return -1;
372	case YNL_PT_BITFIELD32:
373		if (len == sizeof(struct nla_bitfield32))
374			break;
375		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
376		     "Invalid attribute (bitfield32 %s)", policy->name);
377		return -1;
378	default:
379		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
380		     "Invalid attribute (unknown %s)", policy->name);
381		return -1;
382	}
383
384	return 0;
385}
386
387/* Generic code */
388
389static void ynl_err_reset(struct ynl_sock *ys)
390{
391	ys->err.code = 0;
392	ys->err.attr_offs = 0;
393	ys->err.msg[0] = 0;
394}
395
396struct nlmsghdr *ynl_msg_start(struct ynl_sock *ys, __u32 id, __u16 flags)
397{
398	struct nlmsghdr *nlh;
399
400	ynl_err_reset(ys);
401
402	nlh = ys->nlh = ynl_nlmsg_put_header(ys->tx_buf);
403	nlh->nlmsg_type	= id;
404	nlh->nlmsg_flags = flags;
405	nlh->nlmsg_seq = ++ys->seq;
406
407	/* This is a local YNL hack for length checking, we put the buffer
408	 * length in nlmsg_pid, since messages sent to the kernel always use
409	 * PID 0. Message needs to be terminated with ynl_msg_end().
410	 */
411	nlh->nlmsg_pid = YNL_SOCKET_BUFFER_SIZE;
412
413	return nlh;
414}
415
416static int ynl_msg_end(struct ynl_sock *ys, struct nlmsghdr *nlh)
417{
418	/* We stash buffer length in nlmsg_pid. */
419	if (nlh->nlmsg_pid == 0) {
420		yerr(ys, YNL_ERROR_INPUT_INVALID,
421		     "Unknown input buffer length");
422		return -EINVAL;
423	}
424	if (nlh->nlmsg_pid == YNL_MSG_OVERFLOW) {
425		yerr(ys, YNL_ERROR_INPUT_TOO_BIG,
426		     "Constructed message longer than internal buffer");
427		return -EMSGSIZE;
428	}
429
430	nlh->nlmsg_pid = 0;
431	return 0;
432}
433
/*
 * Begin a generic netlink request: a netlink header from ynl_msg_start()
 * followed by a genlmsghdr carrying @cmd and @version (reserved zeroed).
 */
struct nlmsghdr *
ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags,
		__u8 cmd, __u8 version)
{
	struct nlmsghdr *nlh = ynl_msg_start(ys, id, flags);
	struct genlmsghdr gehdr = {
		.cmd		= cmd,
		.version	= version,
	};

	memcpy(ynl_nlmsg_put_extra_header(nlh, sizeof(gehdr)), &gehdr,
	       sizeof(gehdr));

	return nlh;
}
453
/* Start a plain (non-dump) request that asks for an ACK. */
void ynl_msg_start_req(struct ynl_sock *ys, __u32 id)
{
	const __u16 req_flags = NLM_F_REQUEST | NLM_F_ACK;

	(void)ynl_msg_start(ys, id, req_flags);
}
458
/* Start a dump request that asks for an ACK. */
void ynl_msg_start_dump(struct ynl_sock *ys, __u32 id)
{
	const __u16 dump_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;

	(void)ynl_msg_start(ys, id, dump_flags);
}
463
/* Start a generic netlink "do" request that asks for an ACK. */
struct nlmsghdr *
ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
{
	const __u16 req_flags = NLM_F_REQUEST | NLM_F_ACK;

	return ynl_gemsg_start(ys, id, req_flags, cmd, version);
}
469
/* Start a generic netlink dump request that asks for an ACK. */
struct nlmsghdr *
ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
{
	const __u16 dump_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;

	return ynl_gemsg_start(ys, id, dump_flags, cmd, version);
}
476
477static int ynl_cb_null(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg)
478{
479	yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG,
480	     "Received a message when none were expected");
481
482	return YNL_PARSE_CB_ERROR;
483}
484
485static int
486__ynl_sock_read_msgs(struct ynl_parse_arg *yarg, ynl_parse_cb_t cb, int flags)
487{
488	struct ynl_sock *ys = yarg->ys;
489	const struct nlmsghdr *nlh;
490	ssize_t len, rem;
491	int ret;
492
493	len = recv(ys->socket, ys->rx_buf, YNL_SOCKET_BUFFER_SIZE, flags);
494	if (len < 0) {
495		if (flags & MSG_DONTWAIT && errno == EAGAIN)
496			return YNL_PARSE_CB_STOP;
497		return len;
498	}
499
500	ret = YNL_PARSE_CB_STOP;
501	for (rem = len; rem > 0; NLMSG_NEXT(nlh, rem)) {
502		nlh = (struct nlmsghdr *)&ys->rx_buf[len - rem];
503		if (!NLMSG_OK(nlh, rem)) {
504			yerr(yarg->ys, YNL_ERROR_INV_RESP,
505			     "Invalid message or trailing data in the response.");
506			return YNL_PARSE_CB_ERROR;
507		}
508
509		if (nlh->nlmsg_flags & NLM_F_DUMP_INTR) {
510			/* TODO: handle this better */
511			yerr(yarg->ys, YNL_ERROR_DUMP_INTER,
512			     "Dump interrupted / inconsistent, please retry.");
513			return YNL_PARSE_CB_ERROR;
514		}
515
516		switch (nlh->nlmsg_type) {
517		case 0:
518			yerr(yarg->ys, YNL_ERROR_INV_RESP,
519			     "Invalid message type in the response.");
520			return YNL_PARSE_CB_ERROR;
521		case NLMSG_NOOP:
522		case NLMSG_OVERRUN ... NLMSG_MIN_TYPE - 1:
523			ret = YNL_PARSE_CB_OK;
524			break;
525		case NLMSG_ERROR:
526			ret = ynl_cb_error(nlh, yarg);
527			break;
528		case NLMSG_DONE:
529			ret = ynl_cb_done(nlh, yarg);
530			break;
531		default:
532			ret = cb(nlh, yarg);
533			break;
534		}
535	}
536
537	return ret;
538}
539
540static int ynl_sock_read_msgs(struct ynl_parse_arg *yarg, ynl_parse_cb_t cb)
541{
542	return __ynl_sock_read_msgs(yarg, cb, 0);
543}
544
545static int ynl_recv_ack(struct ynl_sock *ys, int ret)
546{
547	struct ynl_parse_arg yarg = { .ys = ys, };
548
549	if (!ret) {
550		yerr(ys, YNL_ERROR_EXPECT_ACK,
551		     "Expecting an ACK but nothing received");
552		return -1;
553	}
554
555	return ynl_sock_read_msgs(&yarg, ynl_cb_null);
556}
557
558/* Init/fini and genetlink boiler plate */
559static int
560ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts)
561{
562	const struct nlattr *entry, *attr;
563	unsigned int i;
564
565	ynl_attr_for_each_nested(attr, mcasts)
566		ys->n_mcast_groups++;
567
568	if (!ys->n_mcast_groups)
569		return 0;
570
571	ys->mcast_groups = calloc(ys->n_mcast_groups,
572				  sizeof(*ys->mcast_groups));
573	if (!ys->mcast_groups)
574		return YNL_PARSE_CB_ERROR;
575
576	i = 0;
577	ynl_attr_for_each_nested(entry, mcasts) {
578		ynl_attr_for_each_nested(attr, entry) {
579			if (ynl_attr_type(attr) == CTRL_ATTR_MCAST_GRP_ID)
580				ys->mcast_groups[i].id = ynl_attr_get_u32(attr);
581			if (ynl_attr_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) {
582				strncpy(ys->mcast_groups[i].name,
583					ynl_attr_get_str(attr),
584					GENL_NAMSIZ - 1);
585				ys->mcast_groups[i].name[GENL_NAMSIZ - 1] = 0;
586			}
587		}
588		i++;
589	}
590
591	return 0;
592}
593
594static int
595ynl_get_family_info_cb(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg)
596{
597	struct ynl_sock *ys = yarg->ys;
598	const struct nlattr *attr;
599	bool found_id = true;
600
601	ynl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) {
602		if (ynl_attr_type(attr) == CTRL_ATTR_MCAST_GROUPS)
603			if (ynl_get_family_info_mcast(ys, attr))
604				return YNL_PARSE_CB_ERROR;
605
606		if (ynl_attr_type(attr) != CTRL_ATTR_FAMILY_ID)
607			continue;
608
609		if (ynl_attr_data_len(attr) != sizeof(__u16)) {
610			yerr(ys, YNL_ERROR_ATTR_INVALID, "Invalid family ID");
611			return YNL_PARSE_CB_ERROR;
612		}
613
614		ys->family_id = ynl_attr_get_u16(attr);
615		found_id = true;
616	}
617
618	if (!found_id) {
619		yerr(ys, YNL_ERROR_ATTR_MISSING, "Family ID missing");
620		return YNL_PARSE_CB_ERROR;
621	}
622	return YNL_PARSE_CB_OK;
623}
624
625static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name)
626{
627	struct ynl_parse_arg yarg = { .ys = ys, };
628	struct nlmsghdr *nlh;
629	int err;
630
631	nlh = ynl_gemsg_start_req(ys, GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 1);
632	ynl_attr_put_str(nlh, CTRL_ATTR_FAMILY_NAME, family_name);
633
634	err = ynl_msg_end(ys, nlh);
635	if (err < 0)
636		return err;
637
638	err = send(ys->socket, nlh, nlh->nlmsg_len, 0);
639	if (err < 0) {
640		perr(ys, "failed to request socket family info");
641		return err;
642	}
643
644	err = ynl_sock_read_msgs(&yarg, ynl_get_family_info_cb);
645	if (err < 0) {
646		free(ys->mcast_groups);
647		perr(ys, "failed to receive the socket family info - no such family?");
648		return err;
649	}
650
651	err = ynl_recv_ack(ys, err);
652	if (err < 0) {
653		free(ys->mcast_groups);
654		return err;
655	}
656
657	return 0;
658}
659
660struct ynl_sock *
661ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse)
662{
663	struct sockaddr_nl addr;
664	struct ynl_sock *ys;
665	socklen_t addrlen;
666	int one = 1;
667
668	ys = malloc(sizeof(*ys) + 2 * YNL_SOCKET_BUFFER_SIZE);
669	if (!ys)
670		return NULL;
671	memset(ys, 0, sizeof(*ys));
672
673	ys->family = yf;
674	ys->tx_buf = &ys->raw_buf[0];
675	ys->rx_buf = &ys->raw_buf[YNL_SOCKET_BUFFER_SIZE];
676	ys->ntf_last_next = &ys->ntf_first;
677
678	ys->socket = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC);
679	if (ys->socket < 0) {
680		__perr(yse, "failed to create a netlink socket");
681		goto err_free_sock;
682	}
683
684	if (setsockopt(ys->socket, SOL_NETLINK, NETLINK_CAP_ACK,
685		       &one, sizeof(one))) {
686		__perr(yse, "failed to enable netlink ACK");
687		goto err_close_sock;
688	}
689	if (setsockopt(ys->socket, SOL_NETLINK, NETLINK_EXT_ACK,
690		       &one, sizeof(one))) {
691		__perr(yse, "failed to enable netlink ext ACK");
692		goto err_close_sock;
693	}
694
695	memset(&addr, 0, sizeof(addr));
696	addr.nl_family = AF_NETLINK;
697	if (bind(ys->socket, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
698		__perr(yse, "unable to bind to a socket address");
699		goto err_close_sock;;
700	}
701
702	memset(&addr, 0, sizeof(addr));
703	addrlen = sizeof(addr);
704	if (getsockname(ys->socket, (struct sockaddr *)&addr, &addrlen) < 0) {
705		__perr(yse, "unable to read socket address");
706		goto err_close_sock;;
707	}
708	ys->portid = addr.nl_pid;
709	ys->seq = random();
710
711
712	if (ynl_sock_read_family(ys, yf->name)) {
713		if (yse)
714			memcpy(yse, &ys->err, sizeof(*yse));
715		goto err_close_sock;
716	}
717
718	return ys;
719
720err_close_sock:
721	close(ys->socket);
722err_free_sock:
723	free(ys);
724	return NULL;
725}
726
727void ynl_sock_destroy(struct ynl_sock *ys)
728{
729	struct ynl_ntf_base_type *ntf;
730
731	close(ys->socket);
732	while ((ntf = ynl_ntf_dequeue(ys)))
733		ynl_ntf_free(ntf);
734	free(ys->mcast_groups);
735	free(ys);
736}
737
738/* YNL multicast handling */
739
740void ynl_ntf_free(struct ynl_ntf_base_type *ntf)
741{
742	ntf->free(ntf);
743}
744
745int ynl_subscribe(struct ynl_sock *ys, const char *grp_name)
746{
747	unsigned int i;
748	int err;
749
750	for (i = 0; i < ys->n_mcast_groups; i++)
751		if (!strcmp(ys->mcast_groups[i].name, grp_name))
752			break;
753	if (i == ys->n_mcast_groups) {
754		yerr(ys, ENOENT, "Multicast group '%s' not found", grp_name);
755		return -1;
756	}
757
758	err = setsockopt(ys->socket, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
759			 &ys->mcast_groups[i].id,
760			 sizeof(ys->mcast_groups[i].id));
761	if (err < 0) {
762		perr(ys, "Subscribing to multicast group failed");
763		return -1;
764	}
765
766	return 0;
767}
768
769int ynl_socket_get_fd(struct ynl_sock *ys)
770{
771	return ys->socket;
772}
773
774struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys)
775{
776	struct ynl_ntf_base_type *ntf;
777
778	if (!ynl_has_ntf(ys))
779		return NULL;
780
781	ntf = ys->ntf_first;
782	ys->ntf_first = ntf->next;
783	if (ys->ntf_last_next == &ntf->next)
784		ys->ntf_last_next = &ys->ntf_first;
785
786	return ntf;
787}
788
789static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh)
790{
791	struct ynl_parse_arg yarg = { .ys = ys, };
792	const struct ynl_ntf_info *info;
793	struct ynl_ntf_base_type *rsp;
794	struct genlmsghdr *gehdr;
795	int ret;
796
797	gehdr = ynl_nlmsg_data(nlh);
798	if (gehdr->cmd >= ys->family->ntf_info_size)
799		return YNL_PARSE_CB_ERROR;
800	info = &ys->family->ntf_info[gehdr->cmd];
801	if (!info->cb)
802		return YNL_PARSE_CB_ERROR;
803
804	rsp = calloc(1, info->alloc_sz);
805	rsp->free = info->free;
806	yarg.data = rsp->data;
807	yarg.rsp_policy = info->policy;
808
809	ret = info->cb(nlh, &yarg);
810	if (ret <= YNL_PARSE_CB_STOP)
811		goto err_free;
812
813	rsp->family = nlh->nlmsg_type;
814	rsp->cmd = gehdr->cmd;
815
816	*ys->ntf_last_next = rsp;
817	ys->ntf_last_next = &rsp->next;
818
819	return YNL_PARSE_CB_OK;
820
821err_free:
822	info->free(rsp);
823	return YNL_PARSE_CB_ERROR;
824}
825
826static int
827ynl_ntf_trampoline(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg)
828{
829	return ynl_ntf_parse(yarg->ys, nlh);
830}
831
832int ynl_ntf_check(struct ynl_sock *ys)
833{
834	struct ynl_parse_arg yarg = { .ys = ys, };
835	int err;
836
837	do {
838		err = __ynl_sock_read_msgs(&yarg, ynl_ntf_trampoline,
839					   MSG_DONTWAIT);
840		if (err < 0)
841			return err;
842	} while (err > 0);
843
844	return 0;
845}
846
847/* YNL specific helpers used by the auto-generated code */
848
/* Sentinel terminating dump result lists; never dereferenced. */
struct ynl_dump_list_type *YNL_LIST_END =
	(struct ynl_dump_list_type *)(0xb4d123);
850
851void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd)
852{
853	yerr(ys, YNL_ERROR_UNKNOWN_NTF,
854	     "Unknown notification message type '%d'", cmd);
855}
856
857int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg)
858{
859	yerr(yarg->ys, YNL_ERROR_INV_RESP, "Error parsing response: %s", msg);
860	return YNL_PARSE_CB_ERROR;
861}
862
863static int
864ynl_check_alien(struct ynl_sock *ys, const struct nlmsghdr *nlh, __u32 rsp_cmd)
865{
866	struct genlmsghdr *gehdr;
867
868	if (ynl_nlmsg_data_len(nlh) < sizeof(*gehdr)) {
869		yerr(ys, YNL_ERROR_INV_RESP,
870		     "Kernel responded with truncated message");
871		return -1;
872	}
873
874	gehdr = ynl_nlmsg_data(nlh);
875	if (gehdr->cmd != rsp_cmd)
876		return ynl_ntf_parse(ys, nlh);
877
878	return 0;
879}
880
881static
882int ynl_req_trampoline(const struct nlmsghdr *nlh, struct ynl_parse_arg *yarg)
883{
884	struct ynl_req_state *yrs = (void *)yarg;
885	int ret;
886
887	ret = ynl_check_alien(yrs->yarg.ys, nlh, yrs->rsp_cmd);
888	if (ret)
889		return ret < 0 ? YNL_PARSE_CB_ERROR : YNL_PARSE_CB_OK;
890
891	return yrs->cb(nlh, &yrs->yarg);
892}
893
894int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
895	     struct ynl_req_state *yrs)
896{
897	int err;
898
899	err = ynl_msg_end(ys, req_nlh);
900	if (err < 0)
901		return err;
902
903	err = send(ys->socket, req_nlh, req_nlh->nlmsg_len, 0);
904	if (err < 0)
905		return err;
906
907	do {
908		err = ynl_sock_read_msgs(&yrs->yarg, ynl_req_trampoline);
909	} while (err > 0);
910
911	return err;
912}
913
914static int
915ynl_dump_trampoline(const struct nlmsghdr *nlh, struct ynl_parse_arg *data)
916{
917	struct ynl_dump_state *ds = (void *)data;
918	struct ynl_dump_list_type *obj;
919	struct ynl_parse_arg yarg = {};
920	int ret;
921
922	ret = ynl_check_alien(ds->yarg.ys, nlh, ds->rsp_cmd);
923	if (ret)
924		return ret < 0 ? YNL_PARSE_CB_ERROR : YNL_PARSE_CB_OK;
925
926	obj = calloc(1, ds->alloc_sz);
927	if (!obj)
928		return YNL_PARSE_CB_ERROR;
929
930	if (!ds->first)
931		ds->first = obj;
932	if (ds->last)
933		ds->last->next = obj;
934	ds->last = obj;
935
936	yarg = ds->yarg;
937	yarg.data = &obj->data;
938
939	return ds->cb(nlh, &yarg);
940}
941
942static void *ynl_dump_end(struct ynl_dump_state *ds)
943{
944	if (!ds->first)
945		return YNL_LIST_END;
946
947	ds->last->next = YNL_LIST_END;
948	return ds->first;
949}
950
951int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
952		  struct ynl_dump_state *yds)
953{
954	int err;
955
956	err = ynl_msg_end(ys, req_nlh);
957	if (err < 0)
958		return err;
959
960	err = send(ys->socket, req_nlh, req_nlh->nlmsg_len, 0);
961	if (err < 0)
962		return err;
963
964	do {
965		err = ynl_sock_read_msgs(&yds->yarg, ynl_dump_trampoline);
966		if (err < 0)
967			goto err_close_list;
968	} while (err > 0);
969
970	yds->first = ynl_dump_end(yds);
971	return 0;
972
973err_close_list:
974	yds->first = ynl_dump_end(yds);
975	return -1;
976}
977