// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2019 Facebook */

#include <linux/bpf.h>
#include <linux/bpf_verifier.h>
#include <linux/btf.h>
#include <linux/filter.h>
#include <linux/slab.h>
#include <linux/numa.h>
#include <linux/seq_file.h>
#include <linux/refcount.h>
#include <linux/mutex.h>
#include <linux/btf_ids.h>
#include <linux/rcupdate_wait.h>

struct bpf_struct_ops_value {
	struct bpf_struct_ops_common_value common;
	char data[] ____cacheline_aligned_in_smp;
};

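/* A single struct_ops map packs its trampolines into at most this many
 * executable pages; bpf_struct_ops_prepare_trampoline() places several
 * trampolines into one page before allocating the next.
 */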
#define MAX_TRAMP_IMAGE_PAGES 8

struct bpf_struct_ops_map {
	struct bpf_map map;
	struct rcu_head rcu;
	const struct bpf_struct_ops_desc *st_ops_desc;
	/* protect map_update */
	struct mutex lock;
	/* links holds all the bpf_links that are populated
	 * to the func ptrs of the kernel's struct
	 * (in kvalue.data).
	 */
	struct bpf_link **links;
	u32 links_cnt;
	u32 image_pages_cnt;
	/* image_pages is an array of pages that hold all the trampolines
	 * that store the func args before calling the bpf_prog.
	 */
	void *image_pages[MAX_TRAMP_IMAGE_PAGES];
	/* The owner module's btf. */
	struct btf *btf;
	/* uvalue->data stores the kernel struct
	 * (e.g. tcp_congestion_ops) that is more useful
	 * to userspace than the kvalue.  For example,
	 * the bpf_prog's id is stored instead of the kernel
	 * address of a func ptr.
	 */
	struct bpf_struct_ops_value *uvalue;
	/* kvalue.data stores the actual kernel's struct
	 * (e.g. tcp_congestion_ops) that will be
	 * registered to the kernel subsystem.
	 */
	struct bpf_struct_ops_value kvalue;
};

struct bpf_struct_ops_link {
	struct bpf_link link;
	struct bpf_map __rcu *map;
};

static DEFINE_MUTEX(update_mutex);

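/* The BTF value type of a struct_ops type named NAME must be named
 * "bpf_struct_ops_NAME"; see is_valid_value_type() and
 * bpf_struct_ops_desc_init() below.
 */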
#define VALUE_PREFIX "bpf_struct_ops_"
#define VALUE_PREFIX_LEN (sizeof(VALUE_PREFIX) - 1)

const struct bpf_verifier_ops bpf_struct_ops_verifier_ops = {
};

const struct bpf_prog_ops bpf_struct_ops_prog_ops = {
#ifdef CONFIG_NET
	.test_run = bpf_struct_ops_test_run,
#endif
};

BTF_ID_LIST(st_ops_ids)
BTF_ID(struct, module)
BTF_ID(struct, bpf_struct_ops_common_value)

enum {
	IDX_MODULE_ID,
	IDX_ST_OPS_COMMON_VALUE_ID,
};

extern struct btf *btf_vmlinux;

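/* For a concrete subsystem, the value type looks roughly like the
 * following (an illustrative sketch, not a type defined in this file):
 *
 *	struct bpf_struct_ops_tcp_congestion_ops {
 *		struct bpf_struct_ops_common_value common;
 *		struct tcp_congestion_ops data;
 *	};
 *
 * i.e. exactly two members: the common header followed by the kernel
 * struct itself.  is_valid_value_type() enforces this shape.
 */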
static bool is_valid_value_type(struct btf *btf, s32 value_id,
				const struct btf_type *type,
				const char *value_name)
{
	const struct btf_type *common_value_type;
	const struct btf_member *member;
	const struct btf_type *vt, *mt;

	vt = btf_type_by_id(btf, value_id);
	if (btf_vlen(vt) != 2) {
		pr_warn("The number of %s's members should be 2, but we get %d\n",
			value_name, btf_vlen(vt));
		return false;
	}
	member = btf_type_member(vt);
	mt = btf_type_by_id(btf, member->type);
	common_value_type = btf_type_by_id(btf_vmlinux,
					   st_ops_ids[IDX_ST_OPS_COMMON_VALUE_ID]);
	if (mt != common_value_type) {
		pr_warn("The first member of %s should be bpf_struct_ops_common_value\n",
			value_name);
		return false;
	}
	member++;
	mt = btf_type_by_id(btf, member->type);
	if (mt != type) {
		pr_warn("The second member of %s should be %s\n",
			value_name, btf_name_by_offset(btf, type->name_off));
		return false;
	}

	return true;
}

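/* Allocate one executable page for trampoline images and charge it
 * against the BPF JIT modmem limit.  The page must be released with
 * bpf_struct_ops_image_free() so the charge is dropped as well.
 */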
static void *bpf_struct_ops_image_alloc(void)
{
	void *image;
	int err;

	err = bpf_jit_charge_modmem(PAGE_SIZE);
	if (err)
		return ERR_PTR(err);
	image = arch_alloc_bpf_trampoline(PAGE_SIZE);
	if (!image) {
		bpf_jit_uncharge_modmem(PAGE_SIZE);
		return ERR_PTR(-ENOMEM);
	}

	return image;
}

void bpf_struct_ops_image_free(void *image)
{
	if (image) {
		arch_free_bpf_trampoline(image, PAGE_SIZE);
		bpf_jit_uncharge_modmem(PAGE_SIZE);
	}
}

#define MAYBE_NULL_SUFFIX "__nullable"
#define MAX_STUB_NAME 128

/* Return the type info of a stub function, if it exists.
 *
 * The name of a stub function is made up of the name of the struct_ops and
 * the name of the function pointer member, separated by "__". For example,
 * if the struct_ops type is named "foo_ops" and the function pointer
 * member is named "bar", the stub function name would be "foo_ops__bar".
 */
static const struct btf_type *
find_stub_func_proto(const struct btf *btf, const char *st_op_name,
		     const char *member_name)
{
	char stub_func_name[MAX_STUB_NAME];
	const struct btf_type *func_type;
	s32 btf_id;
	int cp;

	cp = snprintf(stub_func_name, MAX_STUB_NAME, "%s__%s",
		      st_op_name, member_name);
	if (cp >= MAX_STUB_NAME) {
		pr_warn("Stub function name too long\n");
		return NULL;
	}
	btf_id = btf_find_by_name_kind(btf, stub_func_name, BTF_KIND_FUNC);
	if (btf_id < 0)
		return NULL;
	func_type = btf_type_by_id(btf, btf_id);
	if (!func_type)
		return NULL;

	return btf_type_by_id(btf, func_type->type); /* FUNC_PROTO */
}

/* Prepare argument info for every nullable argument of a member of a
 * struct_ops type.
 *
 * Initialize a struct bpf_struct_ops_arg_info according to type info of
 * the arguments of a stub function. (Check kCFI for more information about
 * stub functions.)
 *
 * Each member in the struct_ops type has a struct bpf_struct_ops_arg_info
 * to provide an array of struct bpf_ctx_arg_aux, which in turn provides
 * the information used by the verifier to check the arguments of the
 * BPF struct_ops program assigned to the member. Here, we only care about
 * the arguments that are marked as __nullable.
 *
 * The array of struct bpf_ctx_arg_aux is eventually assigned to
 * prog->aux->ctx_arg_info of BPF struct_ops programs and passed to the
 * verifier. (See check_struct_ops_btf_id())
 *
 * On success, arg_info->info will point to the array of struct
 * bpf_ctx_arg_aux. On failure, arg_info is left untouched.
 */
static int prepare_arg_info(struct btf *btf,
			    const char *st_ops_name,
			    const char *member_name,
			    const struct btf_type *func_proto,
			    struct bpf_struct_ops_arg_info *arg_info)
{
	const struct btf_type *stub_func_proto, *pointed_type;
	const struct btf_param *stub_args, *args;
	struct bpf_ctx_arg_aux *info, *info_buf;
	u32 nargs, arg_no, info_cnt = 0;
	u32 arg_btf_id;
	int offset;

	stub_func_proto = find_stub_func_proto(btf, st_ops_name, member_name);
	if (!stub_func_proto)
		return 0;

	/* Check if the number of arguments of the stub function is the same
	 * as the number of arguments of the function pointer.
	 */
	nargs = btf_type_vlen(func_proto);
	if (nargs != btf_type_vlen(stub_func_proto)) {
		pr_warn("the number of arguments of the stub function %s__%s does not match the number of arguments of the member %s of struct %s\n",
			st_ops_name, member_name, member_name, st_ops_name);
		return -EINVAL;
	}

	if (!nargs)
		return 0;

	args = btf_params(func_proto);
	stub_args = btf_params(stub_func_proto);

	info_buf = kcalloc(nargs, sizeof(*info_buf), GFP_KERNEL);
	if (!info_buf)
		return -ENOMEM;

	/* Prepare info for every nullable argument */
	info = info_buf;
	for (arg_no = 0; arg_no < nargs; arg_no++) {
		/* Skip arguments that are not suffixed with
		 * "__nullable".
		 */
		if (!btf_param_match_suffix(btf, &stub_args[arg_no],
					    MAYBE_NULL_SUFFIX))
			continue;

		/* Should be a pointer to struct */
		pointed_type = btf_type_resolve_ptr(btf,
						    args[arg_no].type,
						    &arg_btf_id);
		if (!pointed_type ||
		    !btf_type_is_struct(pointed_type)) {
			pr_warn("stub function %s__%s has %s tagging to an unsupported type\n",
				st_ops_name, member_name, MAYBE_NULL_SUFFIX);
			goto err_out;
		}

		offset = btf_ctx_arg_offset(btf, func_proto, arg_no);
		if (offset < 0) {
			pr_warn("stub function %s__%s has an invalid trampoline ctx offset for arg#%u\n",
				st_ops_name, member_name, arg_no);
			goto err_out;
		}

		if (args[arg_no].type != stub_args[arg_no].type) {
			pr_warn("arg#%u type in stub function %s__%s does not match with its original func_proto\n",
				arg_no, st_ops_name, member_name);
			goto err_out;
		}

		/* Fill the information of the new argument */
		info->reg_type =
			PTR_TRUSTED | PTR_TO_BTF_ID | PTR_MAYBE_NULL;
		info->btf_id = arg_btf_id;
		info->btf = btf;
		info->offset = offset;

		info++;
		info_cnt++;
	}

	if (info_cnt) {
		arg_info->info = info_buf;
		arg_info->cnt = info_cnt;
	} else {
		kfree(info_buf);
	}

	return 0;

err_out:
	kfree(info_buf);

	return -EINVAL;
}

/* Clean up the arg_info in a struct bpf_struct_ops_desc. */
void bpf_struct_ops_desc_release(struct bpf_struct_ops_desc *st_ops_desc)
{
	struct bpf_struct_ops_arg_info *arg_info;
	int i;

	arg_info = st_ops_desc->arg_info;
	for (i = 0; i < btf_type_vlen(st_ops_desc->type); i++)
		kfree(arg_info[i].info);

	kfree(arg_info);
}

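/* Build the bpf_struct_ops_desc for one struct_ops type: look up the
 * struct and its "bpf_struct_ops_<name>" value type in @btf, distill a
 * func model for every func ptr member, and prepare the nullable-arg
 * info that the verifier will use for progs attached to this type.
 */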
int bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
			     struct btf *btf,
			     struct bpf_verifier_log *log)
{
	struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
	struct bpf_struct_ops_arg_info *arg_info;
	const struct btf_member *member;
	const struct btf_type *t;
	s32 type_id, value_id;
	char value_name[128];
	const char *mname;
	int i, err;

	if (strlen(st_ops->name) + VALUE_PREFIX_LEN >=
	    sizeof(value_name)) {
		pr_warn("struct_ops name %s is too long\n",
			st_ops->name);
		return -EINVAL;
	}
	sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name);

	if (!st_ops->cfi_stubs) {
		pr_warn("struct_ops for %s has no cfi_stubs\n", st_ops->name);
		return -EINVAL;
	}

	type_id = btf_find_by_name_kind(btf, st_ops->name,
					BTF_KIND_STRUCT);
	if (type_id < 0) {
		pr_warn("Cannot find struct %s in %s\n",
			st_ops->name, btf_get_name(btf));
		return -EINVAL;
	}
	t = btf_type_by_id(btf, type_id);
	if (btf_type_vlen(t) > BPF_STRUCT_OPS_MAX_NR_MEMBERS) {
		pr_warn("Cannot support #%u members in struct %s\n",
			btf_type_vlen(t), st_ops->name);
		return -EINVAL;
	}

	value_id = btf_find_by_name_kind(btf, value_name,
					 BTF_KIND_STRUCT);
	if (value_id < 0) {
		pr_warn("Cannot find struct %s in %s\n",
			value_name, btf_get_name(btf));
		return -EINVAL;
	}
	if (!is_valid_value_type(btf, value_id, t, value_name))
		return -EINVAL;

	arg_info = kcalloc(btf_type_vlen(t), sizeof(*arg_info),
			   GFP_KERNEL);
	if (!arg_info)
		return -ENOMEM;

	st_ops_desc->arg_info = arg_info;
	st_ops_desc->type = t;
	st_ops_desc->type_id = type_id;
	st_ops_desc->value_id = value_id;
	st_ops_desc->value_type = btf_type_by_id(btf, value_id);

	for_each_member(i, t, member) {
		const struct btf_type *func_proto;

		mname = btf_name_by_offset(btf, member->name_off);
		if (!*mname) {
			pr_warn("anon member in struct %s is not supported\n",
				st_ops->name);
			err = -EOPNOTSUPP;
			goto errout;
		}

		if (__btf_member_bitfield_size(t, member)) {
			pr_warn("bit field member %s in struct %s is not supported\n",
				mname, st_ops->name);
			err = -EOPNOTSUPP;
			goto errout;
		}

		func_proto = btf_type_resolve_func_ptr(btf,
						       member->type,
						       NULL);
		if (!func_proto)
			continue;

		if (btf_distill_func_proto(log, btf,
					   func_proto, mname,
					   &st_ops->func_models[i])) {
			pr_warn("Error in parsing func ptr %s in struct %s\n",
				mname, st_ops->name);
			err = -EINVAL;
			goto errout;
		}

		err = prepare_arg_info(btf, st_ops->name, mname,
				       func_proto,
				       arg_info + i);
		if (err)
			goto errout;
	}

	if (st_ops->init(btf)) {
		pr_warn("Error in init bpf_struct_ops %s\n",
			st_ops->name);
		err = -EINVAL;
		goto errout;
	}

	return 0;

errout:
	bpf_struct_ops_desc_release(st_ops_desc);

	return err;
}

static int bpf_struct_ops_map_get_next_key(struct bpf_map *map, void *key,
					   void *next_key)
{
	if (key && *(u32 *)key == 0)
		return -ENOENT;

	*(u32 *)next_key = 0;
	return 0;
}

int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
				       void *value)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	struct bpf_struct_ops_value *uvalue, *kvalue;
	enum bpf_struct_ops_state state;
	s64 refcnt;

	if (unlikely(*(u32 *)key != 0))
		return -ENOENT;

	kvalue = &st_map->kvalue;
	/* Pair with smp_store_release() during map_update */
	state = smp_load_acquire(&kvalue->common.state);
	if (state == BPF_STRUCT_OPS_STATE_INIT) {
		memset(value, 0, map->value_size);
		return 0;
	}

	/* No lock is needed.  state and refcnt do not need
	 * to be updated together under atomic context.
	 */
	uvalue = value;
	memcpy(uvalue, st_map->uvalue, map->value_size);
	uvalue->common.state = state;

	/* This value offers the user space a general estimate of how
	 * many sockets are still utilizing this struct_ops for TCP
	 * congestion control. The number might not be exact, but it
	 * should sufficiently meet our present goals.
	 */
	refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt);
	refcount_set(&uvalue->common.refcnt, max_t(s64, refcnt, 0));

	return 0;
}

static void *bpf_struct_ops_map_lookup_elem(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EINVAL);
}

static void bpf_struct_ops_map_put_progs(struct bpf_struct_ops_map *st_map)
{
	u32 i;

	for (i = 0; i < st_map->links_cnt; i++) {
		if (st_map->links[i]) {
			bpf_link_put(st_map->links[i]);
			st_map->links[i] = NULL;
		}
	}
}

static void bpf_struct_ops_map_free_image(struct bpf_struct_ops_map *st_map)
{
	int i;

	for (i = 0; i < st_map->image_pages_cnt; i++)
		bpf_struct_ops_image_free(st_map->image_pages[i]);
	st_map->image_pages_cnt = 0;
}

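/* Reject a value whose padding holes or trailing bytes are non-zero.
 * Everything in the map value that is not covered by a member of @t
 * must be zeroed by user space.
 */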
static int check_zero_holes(const struct btf *btf, const struct btf_type *t, void *data)
{
	const struct btf_member *member;
	u32 i, moff, msize, prev_mend = 0;
	const struct btf_type *mtype;

	for_each_member(i, t, member) {
		moff = __btf_member_bit_offset(t, member) / 8;
		if (moff > prev_mend &&
		    memchr_inv(data + prev_mend, 0, moff - prev_mend))
			return -EINVAL;

		mtype = btf_type_by_id(btf, member->type);
		mtype = btf_resolve_size(btf, mtype, &msize);
		if (IS_ERR(mtype))
			return PTR_ERR(mtype);
		prev_mend = moff + msize;
	}

	if (t->size > prev_mend &&
	    memchr_inv(data + prev_mend, 0, t->size - prev_mend))
		return -EINVAL;

	return 0;
}

static void bpf_struct_ops_link_release(struct bpf_link *link)
{
}

static void bpf_struct_ops_link_dealloc(struct bpf_link *link)
{
	struct bpf_tramp_link *tlink = container_of(link, struct bpf_tramp_link, link);

	kfree(tlink);
}

const struct bpf_link_ops bpf_struct_ops_link_lops = {
	.release = bpf_struct_ops_link_release,
	.dealloc = bpf_struct_ops_link_dealloc,
};

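/* Emit the trampoline for one struct_ops member into the current image
 * page.  The trampoline stores the func args according to @model and then
 * calls the prog attached through @link.  If the remaining space in
 * *_image cannot hold it and @allow_alloc is true, a fresh page is
 * allocated; on success *_image/*_image_off describe where the next
 * trampoline may be placed.
 */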
int bpf_struct_ops_prepare_trampoline(struct bpf_tramp_links *tlinks,
				      struct bpf_tramp_link *link,
				      const struct btf_func_model *model,
				      void *stub_func,
				      void **_image, u32 *_image_off,
				      bool allow_alloc)
{
	u32 image_off = *_image_off, flags = BPF_TRAMP_F_INDIRECT;
	void *image = *_image;
	int size;

	tlinks[BPF_TRAMP_FENTRY].links[0] = link;
	tlinks[BPF_TRAMP_FENTRY].nr_links = 1;

	if (model->ret_size > 0)
		flags |= BPF_TRAMP_F_RET_FENTRY_RET;

	size = arch_bpf_trampoline_size(model, flags, tlinks, NULL);
	if (size <= 0)
		return size ? : -EFAULT;

	/* Allocate image buffer if necessary */
	if (!image || size > PAGE_SIZE - image_off) {
		if (!allow_alloc)
			return -E2BIG;

		image = bpf_struct_ops_image_alloc();
		if (IS_ERR(image))
			return PTR_ERR(image);
		image_off = 0;
	}

	size = arch_prepare_bpf_trampoline(NULL, image + image_off,
					   image + PAGE_SIZE,
					   model, flags, tlinks, stub_func);
	if (size <= 0) {
		if (image != *_image)
			bpf_struct_ops_image_free(image);
		return size ? : -EFAULT;
	}

	*_image = image;
	*_image_off = image_off + size;
	return 0;
}

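/* A struct_ops map has a single element (key 0).  Updating it
 * materializes the kernel struct: data members are either handled by the
 * subsystem's ->init_member() or must be zero, func ptr members are
 * pointed at trampolines that call the attached struct_ops progs, and the
 * struct is then either registered right away or, for BPF_F_LINK maps,
 * marked READY for a later link attachment.
 */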
static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
					   void *value, u64 flags)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
	const struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
	struct bpf_struct_ops_value *uvalue, *kvalue;
	const struct btf_type *module_type;
	const struct btf_member *member;
	const struct btf_type *t = st_ops_desc->type;
	struct bpf_tramp_links *tlinks;
	void *udata, *kdata;
	int prog_fd, err;
	u32 i, trampoline_start, image_off = 0;
	void *cur_image = NULL, *image = NULL;

	if (flags)
		return -EINVAL;

	if (*(u32 *)key != 0)
		return -E2BIG;

	err = check_zero_holes(st_map->btf, st_ops_desc->value_type, value);
	if (err)
		return err;

	uvalue = value;
	err = check_zero_holes(st_map->btf, t, uvalue->data);
	if (err)
		return err;

	if (uvalue->common.state || refcount_read(&uvalue->common.refcnt))
		return -EINVAL;

	tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL);
	if (!tlinks)
		return -ENOMEM;

	uvalue = (struct bpf_struct_ops_value *)st_map->uvalue;
	kvalue = (struct bpf_struct_ops_value *)&st_map->kvalue;

	mutex_lock(&st_map->lock);

	if (kvalue->common.state != BPF_STRUCT_OPS_STATE_INIT) {
		err = -EBUSY;
		goto unlock;
	}

	memcpy(uvalue, value, map->value_size);

	udata = &uvalue->data;
	kdata = &kvalue->data;

	module_type = btf_type_by_id(btf_vmlinux, st_ops_ids[IDX_MODULE_ID]);
	for_each_member(i, t, member) {
		const struct btf_type *mtype, *ptype;
		struct bpf_prog *prog;
		struct bpf_tramp_link *link;
		u32 moff;

		moff = __btf_member_bit_offset(t, member) / 8;
		ptype = btf_type_resolve_ptr(st_map->btf, member->type, NULL);
		if (ptype == module_type) {
			if (*(void **)(udata + moff))
				goto reset_unlock;
			*(void **)(kdata + moff) = BPF_MODULE_OWNER;
			continue;
		}

		err = st_ops->init_member(t, member, kdata, udata);
		if (err < 0)
			goto reset_unlock;

		/* The ->init_member() has handled this member */
		if (err > 0)
			continue;

		/* If st_ops->init_member does not handle it,
		 * we will only handle func ptrs and zero-ed members
		 * here.  Reject everything else.
		 */

		/* All non-func-ptr members must be 0 */
		if (!ptype || !btf_type_is_func_proto(ptype)) {
			u32 msize;

			mtype = btf_type_by_id(st_map->btf, member->type);
			mtype = btf_resolve_size(st_map->btf, mtype, &msize);
			if (IS_ERR(mtype)) {
				err = PTR_ERR(mtype);
				goto reset_unlock;
			}

			if (memchr_inv(udata + moff, 0, msize)) {
				err = -EINVAL;
				goto reset_unlock;
			}

			continue;
		}

		prog_fd = (int)(*(unsigned long *)(udata + moff));
		/* Similar check as the attr->attach_prog_fd */
		if (!prog_fd)
			continue;

		prog = bpf_prog_get(prog_fd);
		if (IS_ERR(prog)) {
			err = PTR_ERR(prog);
			goto reset_unlock;
		}

		if (prog->type != BPF_PROG_TYPE_STRUCT_OPS ||
		    prog->aux->attach_btf_id != st_ops_desc->type_id ||
		    prog->expected_attach_type != i) {
			bpf_prog_put(prog);
			err = -EINVAL;
			goto reset_unlock;
		}

		link = kzalloc(sizeof(*link), GFP_USER);
		if (!link) {
			bpf_prog_put(prog);
			err = -ENOMEM;
			goto reset_unlock;
		}
		bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS,
			      &bpf_struct_ops_link_lops, prog);
		st_map->links[i] = &link->link;

		trampoline_start = image_off;
		err = bpf_struct_ops_prepare_trampoline(tlinks, link,
						&st_ops->func_models[i],
						*(void **)(st_ops->cfi_stubs + moff),
						&image, &image_off,
						st_map->image_pages_cnt < MAX_TRAMP_IMAGE_PAGES);
		if (err)
			goto reset_unlock;

		if (cur_image != image) {
			st_map->image_pages[st_map->image_pages_cnt++] = image;
			cur_image = image;
			trampoline_start = 0;
		}

		*(void **)(kdata + moff) = image + trampoline_start + cfi_get_offset();

		/* put prog_id to udata */
		*(unsigned long *)(udata + moff) = prog->aux->id;
	}

	if (st_ops->validate) {
		err = st_ops->validate(kdata);
		if (err)
			goto reset_unlock;
	}
	for (i = 0; i < st_map->image_pages_cnt; i++)
		arch_protect_bpf_trampoline(st_map->image_pages[i], PAGE_SIZE);

	if (st_map->map.map_flags & BPF_F_LINK) {
		err = 0;
		/* Let bpf_link handle registration & unregistration.
		 *
		 * Pair with smp_load_acquire() during lookup_elem().
		 */
		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_READY);
		goto unlock;
	}

	err = st_ops->reg(kdata);
	if (likely(!err)) {
		/* This refcnt increment on the map here after
		 * 'st_ops->reg()' is safe since the state of the
		 * map must be set to INIT at this moment, and thus
		 * bpf_struct_ops_map_delete_elem() can't unregister
		 * or transition it to TOBEFREE concurrently.
		 */
		bpf_map_inc(map);
		/* Pair with smp_load_acquire() during lookup_elem().
		 * It ensures the above udata updates (e.g. prog->aux->id)
		 * can be seen once BPF_STRUCT_OPS_STATE_INUSE is set.
		 */
		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_INUSE);
		goto unlock;
	}

	/* Error during st_ops->reg(). Can happen if this struct_ops needs to be
	 * verified as a whole, after all init_member() calls. Can also happen if
	 * there was a race in registering the struct_ops (under the same name) to
	 * a sub-system through different struct_ops's maps.
	 */

reset_unlock:
	bpf_struct_ops_map_free_image(st_map);
	bpf_struct_ops_map_put_progs(st_map);
	memset(uvalue, 0, map->value_size);
	memset(kvalue, 0, map->value_size);
unlock:
	kfree(tlinks);
	mutex_unlock(&st_map->lock);
	return err;
}

static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key)
{
	enum bpf_struct_ops_state prev_state;
	struct bpf_struct_ops_map *st_map;

	st_map = (struct bpf_struct_ops_map *)map;
	if (st_map->map.map_flags & BPF_F_LINK)
		return -EOPNOTSUPP;

	prev_state = cmpxchg(&st_map->kvalue.common.state,
			     BPF_STRUCT_OPS_STATE_INUSE,
			     BPF_STRUCT_OPS_STATE_TOBEFREE);
	switch (prev_state) {
	case BPF_STRUCT_OPS_STATE_INUSE:
		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data);
		bpf_map_put(map);
		return 0;
	case BPF_STRUCT_OPS_STATE_TOBEFREE:
		return -EINPROGRESS;
	case BPF_STRUCT_OPS_STATE_INIT:
		return -ENOENT;
	default:
		WARN_ON_ONCE(1);
		/* Should never happen.  Treat it as not found. */
		return -ENOENT;
	}
}

static void bpf_struct_ops_map_seq_show_elem(struct bpf_map *map, void *key,
					     struct seq_file *m)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	void *value;
	int err;

	value = kmalloc(map->value_size, GFP_USER | __GFP_NOWARN);
	if (!value)
		return;

	err = bpf_struct_ops_map_sys_lookup_elem(map, key, value);
	if (!err) {
		btf_type_seq_show(st_map->btf,
				  map->btf_vmlinux_value_type_id,
				  value, m);
		seq_puts(m, "\n");
	}

	kfree(value);
}

static void __bpf_struct_ops_map_free(struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	if (st_map->links)
		bpf_struct_ops_map_put_progs(st_map);
	bpf_map_area_free(st_map->links);
	bpf_struct_ops_map_free_image(st_map);
	bpf_map_area_free(st_map->uvalue);
	bpf_map_area_free(st_map);
}

static void bpf_struct_ops_map_free(struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	/* st_ops->owner was acquired during map_alloc and implicitly holds
	 * the btf's refcnt. The acquire was only done when btf_is_module()
	 * was true; st_map->btf cannot be NULL here.
	 */
	if (btf_is_module(st_map->btf))
		module_put(st_map->st_ops_desc->st_ops->owner);

	/* The struct_ops's function may switch to another struct_ops.
	 *
	 * For example, bpf_tcp_cc_x->init() may switch to
	 * another tcp_cc_y by calling
	 * setsockopt(TCP_CONGESTION, "tcp_cc_y").
	 * During the switch, bpf_struct_ops_put(tcp_cc_x) is called
	 * and its refcount may reach 0 which then frees its
	 * trampoline image while tcp_cc_x is still running.
	 *
	 * A vanilla rcu gp waits for all bpf-tcp-cc progs
	 * to finish; a bpf-tcp-cc prog is non-sleepable.
	 * A rcu_tasks gp waits for the last few insns
	 * in the trampoline image to finish before releasing
	 * the trampoline image.
	 */
	synchronize_rcu_mult(call_rcu, call_rcu_tasks);

	__bpf_struct_ops_map_free(map);
}

static int bpf_struct_ops_map_alloc_check(union bpf_attr *attr)
{
	if (attr->key_size != sizeof(unsigned int) || attr->max_entries != 1 ||
	    (attr->map_flags & ~(BPF_F_LINK | BPF_F_VTYPE_BTF_OBJ_FD)) ||
	    !attr->btf_vmlinux_value_type_id)
		return -EINVAL;
	return 0;
}

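/* With BPF_F_VTYPE_BTF_OBJ_FD, attr->value_type_btf_obj_fd refers to a
 * module's BTF and the struct_ops type is looked up there (the module is
 * pinned for the lifetime of the map); otherwise the type comes from
 * vmlinux BTF.
 */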
static struct bpf_map *bpf_struct_ops_map_alloc(union bpf_attr *attr)
{
	const struct bpf_struct_ops_desc *st_ops_desc;
	size_t st_map_size;
	struct bpf_struct_ops_map *st_map;
	const struct btf_type *t, *vt;
	struct module *mod = NULL;
	struct bpf_map *map;
	struct btf *btf;
	int ret;

	if (attr->map_flags & BPF_F_VTYPE_BTF_OBJ_FD) {
		/* The map holds btf for its whole lifetime. */
		btf = btf_get_by_fd(attr->value_type_btf_obj_fd);
		if (IS_ERR(btf))
			return ERR_CAST(btf);
		if (!btf_is_module(btf)) {
			btf_put(btf);
			return ERR_PTR(-EINVAL);
		}

		mod = btf_try_get_module(btf);
		/* mod holds a refcnt to btf. We don't need an extra refcnt
		 * here.
		 */
		btf_put(btf);
		if (!mod)
			return ERR_PTR(-EINVAL);
	} else {
		btf = bpf_get_btf_vmlinux();
		if (IS_ERR(btf))
			return ERR_CAST(btf);
		if (!btf)
			return ERR_PTR(-ENOTSUPP);
	}

	st_ops_desc = bpf_struct_ops_find_value(btf, attr->btf_vmlinux_value_type_id);
	if (!st_ops_desc) {
		ret = -ENOTSUPP;
		goto errout;
	}

	vt = st_ops_desc->value_type;
	if (attr->value_size != vt->size) {
		ret = -EINVAL;
		goto errout;
	}

	t = st_ops_desc->type;

	st_map_size = sizeof(*st_map) +
		/* kvalue stores the
		 * struct bpf_struct_ops_tcp_congestion_ops
		 */
		(vt->size - sizeof(struct bpf_struct_ops_value));

	st_map = bpf_map_area_alloc(st_map_size, NUMA_NO_NODE);
	if (!st_map) {
		ret = -ENOMEM;
		goto errout;
	}

	st_map->st_ops_desc = st_ops_desc;
	map = &st_map->map;

	st_map->uvalue = bpf_map_area_alloc(vt->size, NUMA_NO_NODE);
	st_map->links_cnt = btf_type_vlen(t);
	st_map->links =
		bpf_map_area_alloc(st_map->links_cnt * sizeof(struct bpf_link *),
				   NUMA_NO_NODE);
	if (!st_map->uvalue || !st_map->links) {
		ret = -ENOMEM;
		goto errout_free;
	}
	st_map->btf = btf;

	mutex_init(&st_map->lock);
	bpf_map_init_from_attr(map, attr);

	return map;

errout_free:
	__bpf_struct_ops_map_free(map);
errout:
	module_put(mod);

	return ERR_PTR(ret);
}

static u64 bpf_struct_ops_map_mem_usage(const struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
	const struct btf_type *vt = st_ops_desc->value_type;
	u64 usage;

	usage = sizeof(*st_map) +
			vt->size - sizeof(struct bpf_struct_ops_value);
	usage += vt->size;
	usage += btf_type_vlen(vt) * sizeof(struct bpf_link *);
	usage += PAGE_SIZE;
	return usage;
}

BTF_ID_LIST_SINGLE(bpf_struct_ops_map_btf_ids, struct, bpf_struct_ops_map)
const struct bpf_map_ops bpf_struct_ops_map_ops = {
	.map_alloc_check = bpf_struct_ops_map_alloc_check,
	.map_alloc = bpf_struct_ops_map_alloc,
	.map_free = bpf_struct_ops_map_free,
	.map_get_next_key = bpf_struct_ops_map_get_next_key,
	.map_lookup_elem = bpf_struct_ops_map_lookup_elem,
	.map_delete_elem = bpf_struct_ops_map_delete_elem,
	.map_update_elem = bpf_struct_ops_map_update_elem,
	.map_seq_show_elem = bpf_struct_ops_map_seq_show_elem,
	.map_mem_usage = bpf_struct_ops_map_mem_usage,
	.map_btf_id = &bpf_struct_ops_map_btf_ids[0],
};

/* "const void *" because some subsystem is
 * passing a const (e.g. const struct tcp_congestion_ops *)
 */
bool bpf_struct_ops_get(const void *kdata)
{
	struct bpf_struct_ops_value *kvalue;
	struct bpf_struct_ops_map *st_map;
	struct bpf_map *map;

	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);

	map = __bpf_map_inc_not_zero(&st_map->map, false);
	return !IS_ERR(map);
}

void bpf_struct_ops_put(const void *kdata)
{
	struct bpf_struct_ops_value *kvalue;
	struct bpf_struct_ops_map *st_map;

	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);

	bpf_map_put(&st_map->map);
}

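/* Only a BPF_F_LINK struct_ops map that has reached the READY state
 * (i.e. map_update_elem() finished building it) may be registered
 * through a bpf_link.
 */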
static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	return map->map_type == BPF_MAP_TYPE_STRUCT_OPS &&
		map->map_flags & BPF_F_LINK &&
		/* Pair with smp_store_release() during map_update */
		smp_load_acquire(&st_map->kvalue.common.state) == BPF_STRUCT_OPS_STATE_READY;
}

static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link)
{
	struct bpf_struct_ops_link *st_link;
	struct bpf_struct_ops_map *st_map;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	st_map = (struct bpf_struct_ops_map *)
		rcu_dereference_protected(st_link->map, true);
	if (st_map) {
		/* st_link->map can be NULL if
		 * bpf_struct_ops_link_create() fails to register.
		 */
		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data);
		bpf_map_put(&st_map->map);
	}
	kfree(st_link);
}

static void bpf_struct_ops_map_link_show_fdinfo(const struct bpf_link *link,
					    struct seq_file *seq)
{
	struct bpf_struct_ops_link *st_link;
	struct bpf_map *map;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	rcu_read_lock();
	map = rcu_dereference(st_link->map);
	seq_printf(seq, "map_id:\t%d\n", map->id);
	rcu_read_unlock();
}

static int bpf_struct_ops_map_link_fill_link_info(const struct bpf_link *link,
					       struct bpf_link_info *info)
{
	struct bpf_struct_ops_link *st_link;
	struct bpf_map *map;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	rcu_read_lock();
	map = rcu_dereference(st_link->map);
	info->struct_ops.map_id = map->id;
	rcu_read_unlock();
	return 0;
}

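/* Atomically switch a struct_ops link from old_map to new_map.  Both maps
 * must describe the same struct_ops type, and the subsystem must provide
 * an ->update() callback that swaps the registered struct in place.
 */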
static int bpf_struct_ops_map_link_update(struct bpf_link *link, struct bpf_map *new_map,
					  struct bpf_map *expected_old_map)
{
	struct bpf_struct_ops_map *st_map, *old_st_map;
	struct bpf_map *old_map;
	struct bpf_struct_ops_link *st_link;
	int err;

	st_link = container_of(link, struct bpf_struct_ops_link, link);
	st_map = container_of(new_map, struct bpf_struct_ops_map, map);

	if (!bpf_struct_ops_valid_to_reg(new_map))
		return -EINVAL;

	if (!st_map->st_ops_desc->st_ops->update)
		return -EOPNOTSUPP;

	mutex_lock(&update_mutex);

	old_map = rcu_dereference_protected(st_link->map, lockdep_is_held(&update_mutex));
	if (expected_old_map && old_map != expected_old_map) {
		err = -EPERM;
		goto err_out;
	}

	old_st_map = container_of(old_map, struct bpf_struct_ops_map, map);
	/* The new and old struct_ops must be the same type. */
	if (st_map->st_ops_desc != old_st_map->st_ops_desc) {
		err = -EINVAL;
		goto err_out;
	}

	err = st_map->st_ops_desc->st_ops->update(st_map->kvalue.data, old_st_map->kvalue.data);
	if (err)
		goto err_out;

	bpf_map_inc(new_map);
	rcu_assign_pointer(st_link->map, new_map);
	bpf_map_put(old_map);

err_out:
	mutex_unlock(&update_mutex);

	return err;
}

static const struct bpf_link_ops bpf_struct_ops_map_lops = {
	.dealloc = bpf_struct_ops_map_link_dealloc,
	.show_fdinfo = bpf_struct_ops_map_link_show_fdinfo,
	.fill_link_info = bpf_struct_ops_map_link_fill_link_info,
	.update_map = bpf_struct_ops_map_link_update,
};

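/* BPF_LINK_CREATE for a struct_ops map: take a reference on the map,
 * register the kernel struct through st_ops->reg(), and return the link
 * fd.  From then on the link, not map_delete_elem(), controls
 * unregistration (see bpf_struct_ops_map_link_dealloc()).
 */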
int bpf_struct_ops_link_create(union bpf_attr *attr)
{
	struct bpf_struct_ops_link *link = NULL;
	struct bpf_link_primer link_primer;
	struct bpf_struct_ops_map *st_map;
	struct bpf_map *map;
	int err;

	map = bpf_map_get(attr->link_create.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	st_map = (struct bpf_struct_ops_map *)map;

	if (!bpf_struct_ops_valid_to_reg(map)) {
		err = -EINVAL;
		goto err_out;
	}

	link = kzalloc(sizeof(*link), GFP_USER);
	if (!link) {
		err = -ENOMEM;
		goto err_out;
	}
	bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS, &bpf_struct_ops_map_lops, NULL);

	err = bpf_link_prime(&link->link, &link_primer);
	if (err)
		goto err_out;

	err = st_map->st_ops_desc->st_ops->reg(st_map->kvalue.data);
	if (err) {
		bpf_link_cleanup(&link_primer);
		link = NULL;
		goto err_out;
	}
	RCU_INIT_POINTER(link->map, map);

	return bpf_link_settle(&link_primer);

err_out:
	bpf_map_put(map);
	kfree(link);
	return err;
}

void bpf_map_struct_ops_info_fill(struct bpf_map_info *info, struct bpf_map *map)
{
	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;

	info->btf_vmlinux_id = btf_obj_id(st_map->btf);
}
