1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Hyper-V transport for vsock
4 *
5 * Hyper-V Sockets supplies a byte-stream based communication mechanism
6 * between the host and the VM. This driver implements the necessary
7 * support in the VM by introducing the new vsock transport.
8 *
9 * Copyright (c) 2017, Microsoft Corporation.
10 */
11#include <linux/module.h>
12#include <linux/vmalloc.h>
13#include <linux/hyperv.h>
14#include <net/sock.h>
15#include <net/af_vsock.h>
16#include <asm/hyperv-tlfs.h>
17
/* Ring buffer sizing for hv_sock channels.
 *
 * Windows hosts at VMBUS version 'VERSION_WIN10' or older have stricter
 * requirements on the hv_sock ring buffer size: six 4K pages
 * (hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K). Newer hosts drop that
 * limitation, but the defaults stay the same for compatibility.
 */
#define RINGBUFFER_HVS_RCV_SIZE (6 * HV_HYP_PAGE_SIZE)
#define RINGBUFFER_HVS_SND_SIZE (6 * HV_HYP_PAGE_SIZE)
#define RINGBUFFER_HVS_MAX_SIZE (64 * HV_HYP_PAGE_SIZE)

/* Per the host side's design, the MTU is 16KB */
#define HVS_MTU_SIZE		(16 * 1024)

/* Grace period granted to a connection for a graceful shutdown */
#define HVS_CLOSE_TIMEOUT (HZ * 8)
32
33struct vmpipe_proto_header {
34	u32 pkt_type;
35	u32 data_size;
36};
37
38/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
39 * data from the ringbuffer into the userspace buffer.
40 */
41struct hvs_recv_buf {
42	/* The header before the payload data */
43	struct vmpipe_proto_header hdr;
44
45	/* The payload */
46	u8 data[HVS_MTU_SIZE];
47};
48
/* The host would accept up to HVS_MTU_SIZE bytes of payload per packet,
 * but we deliberately send less (HVS_SEND_BUF_SIZE): one VMBUS packet is
 * the smallest processing unit, so smaller packets maximize concurrency
 * between guest and host processing.
 *
 * Note: this bounce buffer can go away once new VMBus ringbuffer APIs
 * allow copying directly from a userspace buffer to the VMBus ringbuffer.
 */
#define HVS_SEND_BUF_SIZE \
		(HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header))
60
61struct hvs_send_buf {
62	/* The header before the payload data */
63	struct vmpipe_proto_header hdr;
64
65	/* The payload */
66	u8 data[HVS_SEND_BUF_SIZE];
67};
68
/* Bytes in front of the payload: VMBus descriptor plus vmpipe header */
#define HVS_HEADER_LEN	(sizeof(struct vmpacket_descriptor) + \
			 sizeof(struct vmpipe_proto_header))

/* Refer to 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(),
 * and __hv_pkt_iter_next().
 */
#define VMBUS_PKT_TRAILER_SIZE	(sizeof(u64))

/* Size of a packet carrying payload_len bytes: the header, the trailer and
 * the payload rounded up to a multiple of 8.
 */
#define HVS_PKT_LEN(payload_len)	(HVS_HEADER_LEN + \
					 VMBUS_PKT_TRAILER_SIZE + \
					 ALIGN((payload_len), 8))

/* Largest VMbus packet hv_sock can ever see */
#define HVS_MAX_PKT_SIZE	HVS_PKT_LEN(HVS_MTU_SIZE)
83
84union hvs_service_id {
85	guid_t	srv_id;
86
87	struct {
88		unsigned int svm_port;
89		unsigned char b[sizeof(guid_t) - sizeof(unsigned int)];
90	};
91};
92
93/* Per-socket state (accessed via vsk->trans) */
94struct hvsock {
95	struct vsock_sock *vsk;
96
97	guid_t vm_srv_id;
98	guid_t host_srv_id;
99
100	struct vmbus_channel *chan;
101	struct vmpacket_descriptor *recv_desc;
102
103	/* The length of the payload not delivered to userland yet */
104	u32 recv_data_len;
105	/* The offset of the payload */
106	u32 recv_data_off;
107
108	/* Have we sent the zero-length packet (FIN)? */
109	bool fin_sent;
110};
111
112/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
113 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
114 * when we write apps to connect to the host, we can only use VMADDR_CID_ANY
115 * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we
116 * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY
117 * as the local cid.
118 *
 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
 * and the endpoint is <VmID, ServiceId> with the below sockaddr:
123 *
124 * struct SOCKADDR_HV
125 * {
126 *    ADDRESS_FAMILY Family;
127 *    USHORT Reserved;
128 *    GUID VmId;
129 *    GUID ServiceId;
130 * };
131 * Note: VmID is not used by Linux VM and actually it isn't transmitted via
132 * VMBus, because here it's obvious the host and the VM can easily identify
133 * each other. Though the VmID is useful on the host, especially in the case
134 * of Windows container, Linux VM doesn't need it at all.
135 *
136 * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit
137 * the available GUID space of SOCKADDR_HV so that we can create a mapping
138 * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing
139 * Hyper-V Sockets apps on the host and in Linux VM is:
140 *
141 ****************************************************************************
142 * The only valid Service GUIDs, from the perspectives of both the host and *
143 * Linux VM, that can be connected by the other end, must conform to this   *
144 * format: <port>-facb-11e6-bd58-64006a7986d3.                              *
145 ****************************************************************************
146 *
147 * When we write apps on the host to connect(), the GUID ServiceID is used.
148 * When we write apps in Linux VM to connect(), we only need to specify the
149 * port and the driver will form the GUID and use that to request the host.
150 *
151 */
152
153/* 00000000-facb-11e6-bd58-64006a7986d3 */
154static const guid_t srv_id_template =
155	GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
156		  0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
157
158static bool hvs_check_transport(struct vsock_sock *vsk);
159
160static bool is_valid_srv_id(const guid_t *id)
161{
162	return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
163}
164
165static unsigned int get_port_by_srv_id(const guid_t *svr_id)
166{
167	return *((unsigned int *)svr_id);
168}
169
170static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id)
171{
172	unsigned int port = get_port_by_srv_id(svr_id);
173
174	vsock_addr_init(addr, VMADDR_CID_ANY, port);
175}
176
177static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
178{
179	set_channel_pending_send_size(chan,
180				      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));
181
182	virt_mb();
183}
184
185static bool hvs_channel_readable(struct vmbus_channel *chan)
186{
187	u32 readable = hv_get_bytes_to_read(&chan->inbound);
188
189	/* 0-size payload means FIN */
190	return readable >= HVS_PKT_LEN(0);
191}
192
193static int hvs_channel_readable_payload(struct vmbus_channel *chan)
194{
195	u32 readable = hv_get_bytes_to_read(&chan->inbound);
196
197	if (readable > HVS_PKT_LEN(0)) {
198		/* At least we have 1 byte to read. We don't need to return
199		 * the exact readable bytes: see vsock_stream_recvmsg() ->
200		 * vsock_stream_has_data().
201		 */
202		return 1;
203	}
204
205	if (readable == HVS_PKT_LEN(0)) {
206		/* 0-size payload means FIN */
207		return 0;
208	}
209
210	/* No payload or FIN */
211	return -1;
212}
213
214static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
215{
216	u32 writeable = hv_get_bytes_to_write(&chan->outbound);
217	size_t ret;
218
219	/* The ringbuffer mustn't be 100% full, and we should reserve a
220	 * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
221	 * and hvs_shutdown().
222	 */
223	if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0))
224		return 0;
225
226	ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0);
227
228	return round_down(ret, 8);
229}
230
231static int __hvs_send_data(struct vmbus_channel *chan,
232			   struct vmpipe_proto_header *hdr,
233			   size_t to_write)
234{
235	hdr->pkt_type = 1;
236	hdr->data_size = to_write;
237	return vmbus_sendpacket(chan, hdr, sizeof(*hdr) + to_write,
238				0, VM_PKT_DATA_INBAND, 0);
239}
240
241static int hvs_send_data(struct vmbus_channel *chan,
242			 struct hvs_send_buf *send_buf, size_t to_write)
243{
244	return __hvs_send_data(chan, &send_buf->hdr, to_write);
245}
246
247static void hvs_channel_cb(void *ctx)
248{
249	struct sock *sk = (struct sock *)ctx;
250	struct vsock_sock *vsk = vsock_sk(sk);
251	struct hvsock *hvs = vsk->trans;
252	struct vmbus_channel *chan = hvs->chan;
253
254	if (hvs_channel_readable(chan))
255		sk->sk_data_ready(sk);
256
257	if (hv_get_bytes_to_write(&chan->outbound) > 0)
258		sk->sk_write_space(sk);
259}
260
261static void hvs_do_close_lock_held(struct vsock_sock *vsk,
262				   bool cancel_timeout)
263{
264	struct sock *sk = sk_vsock(vsk);
265
266	sock_set_flag(sk, SOCK_DONE);
267	vsk->peer_shutdown = SHUTDOWN_MASK;
268	if (vsock_stream_has_data(vsk) <= 0)
269		sk->sk_state = TCP_CLOSING;
270	sk->sk_state_change(sk);
271	if (vsk->close_work_scheduled &&
272	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
273		vsk->close_work_scheduled = false;
274		vsock_remove_sock(vsk);
275
276		/* Release the reference taken while scheduling the timeout */
277		sock_put(sk);
278	}
279}
280
281static void hvs_close_connection(struct vmbus_channel *chan)
282{
283	struct sock *sk = get_per_channel_state(chan);
284
285	lock_sock(sk);
286	hvs_do_close_lock_held(vsock_sk(sk), true);
287	release_sock(sk);
288
289	/* Release the refcnt for the channel that's opened in
290	 * hvs_open_connection().
291	 */
292	sock_put(sk);
293}
294
295static void hvs_open_connection(struct vmbus_channel *chan)
296{
297	guid_t *if_instance, *if_type;
298	unsigned char conn_from_host;
299
300	struct sockaddr_vm addr;
301	struct sock *sk, *new = NULL;
302	struct vsock_sock *vnew = NULL;
303	struct hvsock *hvs = NULL;
304	struct hvsock *hvs_new = NULL;
305	int rcvbuf;
306	int ret;
307	int sndbuf;
308
309	if_type = &chan->offermsg.offer.if_type;
310	if_instance = &chan->offermsg.offer.if_instance;
311	conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];
312	if (!is_valid_srv_id(if_type))
313		return;
314
315	hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
316	sk = vsock_find_bound_socket(&addr);
317	if (!sk)
318		return;
319
320	lock_sock(sk);
321	if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
322	    (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
323		goto out;
324
325	if (conn_from_host) {
326		if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
327			goto out;
328
329		new = vsock_create_connected(sk);
330		if (!new)
331			goto out;
332
333		new->sk_state = TCP_SYN_SENT;
334		vnew = vsock_sk(new);
335
336		hvs_addr_init(&vnew->local_addr, if_type);
337
338		/* Remote peer is always the host */
339		vsock_addr_init(&vnew->remote_addr,
340				VMADDR_CID_HOST, VMADDR_PORT_ANY);
341		vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
342		ret = vsock_assign_transport(vnew, vsock_sk(sk));
343		/* Transport assigned (looking at remote_addr) must be the
344		 * same where we received the request.
345		 */
346		if (ret || !hvs_check_transport(vnew)) {
347			sock_put(new);
348			goto out;
349		}
350		hvs_new = vnew->trans;
351		hvs_new->chan = chan;
352	} else {
353		hvs = vsock_sk(sk)->trans;
354		hvs->chan = chan;
355	}
356
357	set_channel_read_mode(chan, HV_CALL_DIRECT);
358
359	/* Use the socket buffer sizes as hints for the VMBUS ring size. For
360	 * server side sockets, 'sk' is the parent socket and thus, this will
361	 * allow the child sockets to inherit the size from the parent. Keep
362	 * the mins to the default value and align to page size as per VMBUS
363	 * requirements.
364	 * For the max, the socket core library will limit the socket buffer
365	 * size that can be set by the user, but, since currently, the hv_sock
366	 * VMBUS ring buffer is physically contiguous allocation, restrict it
367	 * further.
368	 * Older versions of hv_sock host side code cannot handle bigger VMBUS
369	 * ring buffer size. Use the version number to limit the change to newer
370	 * versions.
371	 */
372	if (vmbus_proto_version < VERSION_WIN10_V5) {
373		sndbuf = RINGBUFFER_HVS_SND_SIZE;
374		rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
375	} else {
376		sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
377		sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
378		sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE);
379		rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
380		rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
381		rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE);
382	}
383
384	chan->max_pkt_size = HVS_MAX_PKT_SIZE;
385
386	ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
387			 conn_from_host ? new : sk);
388	if (ret != 0) {
389		if (conn_from_host) {
390			hvs_new->chan = NULL;
391			sock_put(new);
392		} else {
393			hvs->chan = NULL;
394		}
395		goto out;
396	}
397
398	set_per_channel_state(chan, conn_from_host ? new : sk);
399
400	/* This reference will be dropped by hvs_close_connection(). */
401	sock_hold(conn_from_host ? new : sk);
402	vmbus_set_chn_rescind_callback(chan, hvs_close_connection);
403
404	/* Set the pending send size to max packet size to always get
405	 * notifications from the host when there is enough writable space.
406	 * The host is optimized to send notifications only when the pending
407	 * size boundary is crossed, and not always.
408	 */
409	hvs_set_channel_pending_send_size(chan);
410
411	if (conn_from_host) {
412		new->sk_state = TCP_ESTABLISHED;
413		sk_acceptq_added(sk);
414
415		hvs_new->vm_srv_id = *if_type;
416		hvs_new->host_srv_id = *if_instance;
417
418		vsock_insert_connected(vnew);
419
420		vsock_enqueue_accept(sk, new);
421	} else {
422		sk->sk_state = TCP_ESTABLISHED;
423		sk->sk_socket->state = SS_CONNECTED;
424
425		vsock_insert_connected(vsock_sk(sk));
426	}
427
428	sk->sk_state_change(sk);
429
430out:
431	/* Release refcnt obtained when we called vsock_find_bound_socket() */
432	sock_put(sk);
433
434	release_sock(sk);
435}
436
437static u32 hvs_get_local_cid(void)
438{
439	return VMADDR_CID_ANY;
440}
441
442static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
443{
444	struct hvsock *hvs;
445	struct sock *sk = sk_vsock(vsk);
446
447	hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
448	if (!hvs)
449		return -ENOMEM;
450
451	vsk->trans = hvs;
452	hvs->vsk = vsk;
453	sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE;
454	sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
455	return 0;
456}
457
458static int hvs_connect(struct vsock_sock *vsk)
459{
460	union hvs_service_id vm, host;
461	struct hvsock *h = vsk->trans;
462
463	vm.srv_id = srv_id_template;
464	vm.svm_port = vsk->local_addr.svm_port;
465	h->vm_srv_id = vm.srv_id;
466
467	host.srv_id = srv_id_template;
468	host.svm_port = vsk->remote_addr.svm_port;
469	h->host_srv_id = host.srv_id;
470
471	return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
472}
473
474static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
475{
476	struct vmpipe_proto_header hdr;
477
478	if (hvs->fin_sent || !hvs->chan)
479		return;
480
481	/* It can't fail: see hvs_channel_writable_bytes(). */
482	(void)__hvs_send_data(hvs->chan, &hdr, 0);
483	hvs->fin_sent = true;
484}
485
486static int hvs_shutdown(struct vsock_sock *vsk, int mode)
487{
488	if (!(mode & SEND_SHUTDOWN))
489		return 0;
490
491	hvs_shutdown_lock_held(vsk->trans, mode);
492	return 0;
493}
494
495static void hvs_close_timeout(struct work_struct *work)
496{
497	struct vsock_sock *vsk =
498		container_of(work, struct vsock_sock, close_work.work);
499	struct sock *sk = sk_vsock(vsk);
500
501	sock_hold(sk);
502	lock_sock(sk);
503	if (!sock_flag(sk, SOCK_DONE))
504		hvs_do_close_lock_held(vsk, false);
505
506	vsk->close_work_scheduled = false;
507	release_sock(sk);
508	sock_put(sk);
509}
510
511/* Returns true, if it is safe to remove socket; false otherwise */
512static bool hvs_close_lock_held(struct vsock_sock *vsk)
513{
514	struct sock *sk = sk_vsock(vsk);
515
516	if (!(sk->sk_state == TCP_ESTABLISHED ||
517	      sk->sk_state == TCP_CLOSING))
518		return true;
519
520	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
521		hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);
522
523	if (sock_flag(sk, SOCK_DONE))
524		return true;
525
526	/* This reference will be dropped by the delayed close routine */
527	sock_hold(sk);
528	INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
529	vsk->close_work_scheduled = true;
530	schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
531	return false;
532}
533
/* Transport release hook: start the close and remove the socket at once
 * when hvs_close_lock_held() reports that this is safe.
 */
static void hvs_release(struct vsock_sock *vsk)
{
	if (hvs_close_lock_held(vsk))
		vsock_remove_sock(vsk);
}
542
543static void hvs_destruct(struct vsock_sock *vsk)
544{
545	struct hvsock *hvs = vsk->trans;
546	struct vmbus_channel *chan = hvs->chan;
547
548	if (chan)
549		vmbus_hvsock_device_unregister(chan);
550
551	kfree(hvs);
552}
553
554static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
555{
556	return -EOPNOTSUPP;
557}
558
559static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
560			     size_t len, int flags)
561{
562	return -EOPNOTSUPP;
563}
564
565static int hvs_dgram_enqueue(struct vsock_sock *vsk,
566			     struct sockaddr_vm *remote, struct msghdr *msg,
567			     size_t dgram_len)
568{
569	return -EOPNOTSUPP;
570}
571
572static bool hvs_dgram_allow(u32 cid, u32 port)
573{
574	return false;
575}
576
577static int hvs_update_recv_data(struct hvsock *hvs)
578{
579	struct hvs_recv_buf *recv_buf;
580	u32 pkt_len, payload_len;
581
582	pkt_len = hv_pkt_len(hvs->recv_desc);
583
584	if (pkt_len < HVS_HEADER_LEN)
585		return -EIO;
586
587	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
588	payload_len = recv_buf->hdr.data_size;
589
590	if (payload_len > pkt_len - HVS_HEADER_LEN ||
591	    payload_len > HVS_MTU_SIZE)
592		return -EIO;
593
594	if (payload_len == 0)
595		hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
596
597	hvs->recv_data_len = payload_len;
598	hvs->recv_data_off = 0;
599
600	return 0;
601}
602
603static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
604				  size_t len, int flags)
605{
606	struct hvsock *hvs = vsk->trans;
607	bool need_refill = !hvs->recv_desc;
608	struct hvs_recv_buf *recv_buf;
609	u32 to_read;
610	int ret;
611
612	if (flags & MSG_PEEK)
613		return -EOPNOTSUPP;
614
615	if (need_refill) {
616		hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
617		if (!hvs->recv_desc)
618			return -ENOBUFS;
619		ret = hvs_update_recv_data(hvs);
620		if (ret)
621			return ret;
622	}
623
624	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
625	to_read = min_t(u32, len, hvs->recv_data_len);
626	ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
627	if (ret != 0)
628		return ret;
629
630	hvs->recv_data_len -= to_read;
631	if (hvs->recv_data_len == 0) {
632		hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
633		if (hvs->recv_desc) {
634			ret = hvs_update_recv_data(hvs);
635			if (ret)
636				return ret;
637		}
638	} else {
639		hvs->recv_data_off += to_read;
640	}
641
642	return to_read;
643}
644
645static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
646				  size_t len)
647{
648	struct hvsock *hvs = vsk->trans;
649	struct vmbus_channel *chan = hvs->chan;
650	struct hvs_send_buf *send_buf;
651	ssize_t to_write, max_writable;
652	ssize_t ret = 0;
653	ssize_t bytes_written = 0;
654
655	BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE);
656
657	send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
658	if (!send_buf)
659		return -ENOMEM;
660
661	/* Reader(s) could be draining data from the channel as we write.
662	 * Maximize bandwidth, by iterating until the channel is found to be
663	 * full.
664	 */
665	while (len) {
666		max_writable = hvs_channel_writable_bytes(chan);
667		if (!max_writable)
668			break;
669		to_write = min_t(ssize_t, len, max_writable);
670		to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
671		/* memcpy_from_msg is safe for loop as it advances the offsets
672		 * within the message iterator.
673		 */
674		ret = memcpy_from_msg(send_buf->data, msg, to_write);
675		if (ret < 0)
676			goto out;
677
678		ret = hvs_send_data(hvs->chan, send_buf, to_write);
679		if (ret < 0)
680			goto out;
681
682		bytes_written += to_write;
683		len -= to_write;
684	}
685out:
686	/* If any data has been sent, return that */
687	if (bytes_written)
688		ret = bytes_written;
689	kfree(send_buf);
690	return ret;
691}
692
693static s64 hvs_stream_has_data(struct vsock_sock *vsk)
694{
695	struct hvsock *hvs = vsk->trans;
696	s64 ret;
697
698	if (hvs->recv_data_len > 0)
699		return 1;
700
701	switch (hvs_channel_readable_payload(hvs->chan)) {
702	case 1:
703		ret = 1;
704		break;
705	case 0:
706		vsk->peer_shutdown |= SEND_SHUTDOWN;
707		ret = 0;
708		break;
709	default: /* -1 */
710		ret = 0;
711		break;
712	}
713
714	return ret;
715}
716
717static s64 hvs_stream_has_space(struct vsock_sock *vsk)
718{
719	struct hvsock *hvs = vsk->trans;
720
721	return hvs_channel_writable_bytes(hvs->chan);
722}
723
724static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
725{
726	return HVS_MTU_SIZE + 1;
727}
728
729static bool hvs_stream_is_active(struct vsock_sock *vsk)
730{
731	struct hvsock *hvs = vsk->trans;
732
733	return hvs->chan != NULL;
734}
735
736static bool hvs_stream_allow(u32 cid, u32 port)
737{
738	if (cid == VMADDR_CID_HOST)
739		return true;
740
741	return false;
742}
743
744static
745int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable)
746{
747	struct hvsock *hvs = vsk->trans;
748
749	*readable = hvs_channel_readable(hvs->chan);
750	return 0;
751}
752
753static
754int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable)
755{
756	*writable = hvs_stream_has_space(vsk) > 0;
757
758	return 0;
759}
760
761static
762int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
763			 struct vsock_transport_recv_notify_data *d)
764{
765	return 0;
766}
767
768static
769int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
770			      struct vsock_transport_recv_notify_data *d)
771{
772	return 0;
773}
774
775static
776int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
777				struct vsock_transport_recv_notify_data *d)
778{
779	return 0;
780}
781
782static
783int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
784				 ssize_t copied, bool data_read,
785				 struct vsock_transport_recv_notify_data *d)
786{
787	return 0;
788}
789
/* No-op notify hook: ignores its arguments and always returns 0 */
static int hvs_notify_send_init(struct vsock_sock *vsk,
				struct vsock_transport_send_notify_data *d)
{
	return 0;
}
796
/* No-op notify hook: ignores its arguments and always returns 0 */
static int
hvs_notify_send_pre_block(struct vsock_sock *vsk,
			  struct vsock_transport_send_notify_data *d)
{
	return 0;
}
803
/* No-op notify hook: ignores its arguments and always returns 0 */
static int
hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
			    struct vsock_transport_send_notify_data *d)
{
	return 0;
}
810
811static
812int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
813				 struct vsock_transport_send_notify_data *d)
814{
815	return 0;
816}
817
818static
819int hvs_notify_set_rcvlowat(struct vsock_sock *vsk, int val)
820{
821	return -EOPNOTSUPP;
822}
823
824static struct vsock_transport hvs_transport = {
825	.module                   = THIS_MODULE,
826
827	.get_local_cid            = hvs_get_local_cid,
828
829	.init                     = hvs_sock_init,
830	.destruct                 = hvs_destruct,
831	.release                  = hvs_release,
832	.connect                  = hvs_connect,
833	.shutdown                 = hvs_shutdown,
834
835	.dgram_bind               = hvs_dgram_bind,
836	.dgram_dequeue            = hvs_dgram_dequeue,
837	.dgram_enqueue            = hvs_dgram_enqueue,
838	.dgram_allow              = hvs_dgram_allow,
839
840	.stream_dequeue           = hvs_stream_dequeue,
841	.stream_enqueue           = hvs_stream_enqueue,
842	.stream_has_data          = hvs_stream_has_data,
843	.stream_has_space         = hvs_stream_has_space,
844	.stream_rcvhiwat          = hvs_stream_rcvhiwat,
845	.stream_is_active         = hvs_stream_is_active,
846	.stream_allow             = hvs_stream_allow,
847
848	.notify_poll_in           = hvs_notify_poll_in,
849	.notify_poll_out          = hvs_notify_poll_out,
850	.notify_recv_init         = hvs_notify_recv_init,
851	.notify_recv_pre_block    = hvs_notify_recv_pre_block,
852	.notify_recv_pre_dequeue  = hvs_notify_recv_pre_dequeue,
853	.notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
854	.notify_send_init         = hvs_notify_send_init,
855	.notify_send_pre_block    = hvs_notify_send_pre_block,
856	.notify_send_pre_enqueue  = hvs_notify_send_pre_enqueue,
857	.notify_send_post_enqueue = hvs_notify_send_post_enqueue,
858
859	.notify_set_rcvlowat      = hvs_notify_set_rcvlowat
860};
861
862static bool hvs_check_transport(struct vsock_sock *vsk)
863{
864	return vsk->transport == &hvs_transport;
865}
866
867static int hvs_probe(struct hv_device *hdev,
868		     const struct hv_vmbus_device_id *dev_id)
869{
870	struct vmbus_channel *chan = hdev->channel;
871
872	hvs_open_connection(chan);
873
874	/* Always return success to suppress the unnecessary error message
875	 * in vmbus_probe(): on error the host will rescind the device in
876	 * 30 seconds and we can do cleanup at that time in
877	 * vmbus_onoffer_rescind().
878	 */
879	return 0;
880}
881
882static void hvs_remove(struct hv_device *hdev)
883{
884	struct vmbus_channel *chan = hdev->channel;
885
886	vmbus_close(chan);
887}
888
/* hv_sock connections cannot persist across hibernation, and all the
 * hv_sock channels are forced to be rescinded before hibernation: see
 * vmbus_bus_suspend(). The dummy hvs_suspend() and hvs_resume() exist only
 * because hibernation requires every vmbus device's driver to provide a
 * .suspend and a .resume callback: see vmbus_suspend().
 */
static int hvs_suspend(struct hv_device *hv_dev)
{
	/* Intentionally empty: nothing to do */
	return 0;
}
900
/* Dummy resume callback, the counterpart of hvs_suspend(); always
 * returns 0.
 */
static int hvs_resume(struct hv_device *dev)
{
	/* Intentionally empty: nothing to do */
	return 0;
}
906
907/* This isn't really used. See vmbus_match() and vmbus_probe() */
908static const struct hv_vmbus_device_id id_table[] = {
909	{},
910};
911
912static struct hv_driver hvs_drv = {
913	.name		= "hv_sock",
914	.hvsock		= true,
915	.id_table	= id_table,
916	.probe		= hvs_probe,
917	.remove		= hvs_remove,
918	.suspend	= hvs_suspend,
919	.resume		= hvs_resume,
920};
921
922static int __init hvs_init(void)
923{
924	int ret;
925
926	if (vmbus_proto_version < VERSION_WIN10)
927		return -ENODEV;
928
929	ret = vmbus_driver_register(&hvs_drv);
930	if (ret != 0)
931		return ret;
932
933	ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H);
934	if (ret) {
935		vmbus_driver_unregister(&hvs_drv);
936		return ret;
937	}
938
939	return 0;
940}
941
942static void __exit hvs_exit(void)
943{
944	vsock_core_unregister(&hvs_transport);
945	vmbus_driver_unregister(&hvs_drv);
946}
947
948module_init(hvs_init);
949module_exit(hvs_exit);
950
951MODULE_DESCRIPTION("Hyper-V Sockets");
952MODULE_VERSION("1.0.0");
953MODULE_LICENSE("GPL");
954MODULE_ALIAS_NETPROTO(PF_VSOCK);
955