1// SPDX-License-Identifier: GPL-2.0-or-later
2/*
3 *   Copyright (C) 2017, Microsoft Corporation.
4 *   Copyright (C) 2018, LG Electronics.
5 *
6 *   Author(s): Long Li <longli@microsoft.com>,
7 *		Hyunchul Lee <hyc.lee@gmail.com>
8 */
9
10#define SUBMOD_NAME	"smb_direct"
11
12#include <linux/kthread.h>
13#include <linux/list.h>
14#include <linux/mempool.h>
15#include <linux/highmem.h>
16#include <linux/scatterlist.h>
17#include <rdma/ib_verbs.h>
18#include <rdma/rdma_cm.h>
19#include <rdma/rw.h>
20
21#include "glob.h"
22#include "connection.h"
23#include "smb_common.h"
24#include "smbstatus.h"
25#include "transport_rdma.h"
26
27#define SMB_DIRECT_PORT_IWARP		5445
28#define SMB_DIRECT_PORT_INFINIBAND	445
29
30#define SMB_DIRECT_VERSION_LE		cpu_to_le16(0x0100)
31
32/* SMB_DIRECT negotiation timeout in seconds */
33#define SMB_DIRECT_NEGOTIATE_TIMEOUT		120
34
35#define SMB_DIRECT_MAX_SEND_SGES		6
36#define SMB_DIRECT_MAX_RECV_SGES		1
37
/*
 * Default maximum number of outstanding RDMA read/write operations on this
 * connection. This value may be lowered during QP creation if the hardware
 * imposes a smaller limit.
 */
42#define SMB_DIRECT_CM_INITIATOR_DEPTH		8
43
44/* Maximum number of retries on data transfer operations */
45#define SMB_DIRECT_CM_RETRY			6
46/* No need to retry on Receiver Not Ready since SMB_DIRECT manages credits */
47#define SMB_DIRECT_CM_RNR_RETRY		0
48
/*
 * User-configurable initial values per SMB_DIRECT transport connection,
 * as defined in [MS-SMBD] 3.1.1.1.
 * These may change after SMB_DIRECT negotiation.
 */
54
/* Use port 445 as the SMB Direct port by default */
56static int smb_direct_port = SMB_DIRECT_PORT_INFINIBAND;
57
/* The maximum number of credits the local peer will grant to the remote peer */
59static int smb_direct_receive_credit_max = 255;
60
/* The number of send credits the local peer requests from the remote peer */
62static int smb_direct_send_credit_target = 255;
63
/* The maximum single-message size that can be sent to the remote peer */
65static int smb_direct_max_send_size = 1364;
66
67/*  The maximum fragmented upper-layer payload receive size supported */
68static int smb_direct_max_fragmented_recv_size = 1024 * 1024;
69
70/*  The maximum single-message size which can be received */
71static int smb_direct_max_receive_size = 1364;
72
73static int smb_direct_max_read_write_size = SMBD_DEFAULT_IOSIZE;
74
75static LIST_HEAD(smb_direct_device_list);
76static DEFINE_RWLOCK(smb_direct_device_lock);
77
78struct smb_direct_device {
79	struct ib_device	*ib_dev;
80	struct list_head	list;
81};
82
83static struct smb_direct_listener {
84	struct rdma_cm_id	*cm_id;
85} smb_direct_listener;
86
87static struct workqueue_struct *smb_direct_wq;
88
89enum smb_direct_status {
90	SMB_DIRECT_CS_NEW = 0,
91	SMB_DIRECT_CS_CONNECTED,
92	SMB_DIRECT_CS_DISCONNECTING,
93	SMB_DIRECT_CS_DISCONNECTED,
94};
95
96struct smb_direct_transport {
97	struct ksmbd_transport	transport;
98
99	enum smb_direct_status	status;
100	bool			full_packet_received;
101	wait_queue_head_t	wait_status;
102
103	struct rdma_cm_id	*cm_id;
104	struct ib_cq		*send_cq;
105	struct ib_cq		*recv_cq;
106	struct ib_pd		*pd;
107	struct ib_qp		*qp;
108
109	int			max_send_size;
110	int			max_recv_size;
111	int			max_fragmented_send_size;
112	int			max_fragmented_recv_size;
113	int			max_rdma_rw_size;
114
115	spinlock_t		reassembly_queue_lock;
116	struct list_head	reassembly_queue;
117	int			reassembly_data_length;
118	int			reassembly_queue_length;
119	int			first_entry_offset;
120	wait_queue_head_t	wait_reassembly_queue;
121
122	spinlock_t		receive_credit_lock;
123	int			recv_credits;
124	int			count_avail_recvmsg;
125	int			recv_credit_max;
126	int			recv_credit_target;
127
128	spinlock_t		recvmsg_queue_lock;
129	struct list_head	recvmsg_queue;
130
131	spinlock_t		empty_recvmsg_queue_lock;
132	struct list_head	empty_recvmsg_queue;
133
134	int			send_credit_target;
135	atomic_t		send_credits;
136	spinlock_t		lock_new_recv_credits;
137	int			new_recv_credits;
138	int			max_rw_credits;
139	int			pages_per_rw_credit;
140	atomic_t		rw_credits;
141
142	wait_queue_head_t	wait_send_credits;
143	wait_queue_head_t	wait_rw_credits;
144
145	mempool_t		*sendmsg_mempool;
146	struct kmem_cache	*sendmsg_cache;
147	mempool_t		*recvmsg_mempool;
148	struct kmem_cache	*recvmsg_cache;
149
150	wait_queue_head_t	wait_send_pending;
151	atomic_t		send_pending;
152
153	struct delayed_work	post_recv_credits_work;
154	struct work_struct	send_immediate_work;
155	struct work_struct	disconnect_work;
156
157	bool			negotiation_requested;
158};
159
160#define KSMBD_TRANS(t) ((struct ksmbd_transport *)&((t)->transport))
161
162enum {
163	SMB_DIRECT_MSG_NEGOTIATE_REQ = 0,
164	SMB_DIRECT_MSG_DATA_TRANSFER
165};
166
167static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops;
168
169struct smb_direct_send_ctx {
170	struct list_head	msg_list;
171	int			wr_cnt;
172	bool			need_invalidate_rkey;
173	unsigned int		remote_key;
174};
175
176struct smb_direct_sendmsg {
177	struct smb_direct_transport	*transport;
178	struct ib_send_wr	wr;
179	struct list_head	list;
180	int			num_sge;
181	struct ib_sge		sge[SMB_DIRECT_MAX_SEND_SGES];
182	struct ib_cqe		cqe;
183	u8			packet[];
184};
185
186struct smb_direct_recvmsg {
187	struct smb_direct_transport	*transport;
188	struct list_head	list;
189	int			type;
190	struct ib_sge		sge;
191	struct ib_cqe		cqe;
192	bool			first_segment;
193	u8			packet[];
194};
195
196struct smb_direct_rdma_rw_msg {
197	struct smb_direct_transport	*t;
198	struct ib_cqe		cqe;
199	int			status;
200	struct completion	*completion;
201	struct list_head	list;
202	struct rdma_rw_ctx	rw_ctx;
203	struct sg_table		sgt;
204	struct scatterlist	sg_list[];
205};
206
207void init_smbd_max_io_size(unsigned int sz)
208{
209	sz = clamp_val(sz, SMBD_MIN_IOSIZE, SMBD_MAX_IOSIZE);
210	smb_direct_max_read_write_size = sz;
211}
212
213unsigned int get_smbd_max_read_write_size(void)
214{
215	return smb_direct_max_read_write_size;
216}
217
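/*
 * get_buf_page_count() returns how many pages a buffer touches. For example
 * (assuming 4 KiB pages and taking the buffer's page-aligned base as 0), a
 * buffer at page offset 0xf00 with size 0x300 gives
 * DIV_ROUND_UP(0xf00 + 0x300, 0x1000) - 0xf00 / 0x1000 = 2 - 0 = 2 pages,
 * since the last 0x200 bytes spill onto a second page.
 */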
218static inline int get_buf_page_count(void *buf, int size)
219{
220	return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
221		(uintptr_t)buf / PAGE_SIZE;
222}
223
224static void smb_direct_destroy_pools(struct smb_direct_transport *transport);
225static void smb_direct_post_recv_credits(struct work_struct *work);
226static int smb_direct_post_send_data(struct smb_direct_transport *t,
227				     struct smb_direct_send_ctx *send_ctx,
228				     struct kvec *iov, int niov,
229				     int remaining_data_length);
230
231static inline struct smb_direct_transport *
232smb_trans_direct_transfort(struct ksmbd_transport *t)
233{
234	return container_of(t, struct smb_direct_transport, transport);
235}
236
237static inline void
238*smb_direct_recvmsg_payload(struct smb_direct_recvmsg *recvmsg)
239{
240	return (void *)recvmsg->packet;
241}
242
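/*
 * Heuristic for when to replenish receive buffers: only once the remaining
 * receive credits have dropped to one eighth or less of the configured
 * maximum, and only if enough free receive messages have accumulated to make
 * the refill worthwhile (at least a quarter of the outstanding credits).
 */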
243static inline bool is_receive_credit_post_required(int receive_credits,
244						   int avail_recvmsg_count)
245{
246	return receive_credits <= (smb_direct_receive_credit_max >> 3) &&
247		avail_recvmsg_count >= (receive_credits >> 2);
248}
249
250static struct
251smb_direct_recvmsg *get_free_recvmsg(struct smb_direct_transport *t)
252{
253	struct smb_direct_recvmsg *recvmsg = NULL;
254
255	spin_lock(&t->recvmsg_queue_lock);
256	if (!list_empty(&t->recvmsg_queue)) {
257		recvmsg = list_first_entry(&t->recvmsg_queue,
258					   struct smb_direct_recvmsg,
259					   list);
260		list_del(&recvmsg->list);
261	}
262	spin_unlock(&t->recvmsg_queue_lock);
263	return recvmsg;
264}
265
266static void put_recvmsg(struct smb_direct_transport *t,
267			struct smb_direct_recvmsg *recvmsg)
268{
269	ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
270			    recvmsg->sge.length, DMA_FROM_DEVICE);
271
272	spin_lock(&t->recvmsg_queue_lock);
273	list_add(&recvmsg->list, &t->recvmsg_queue);
274	spin_unlock(&t->recvmsg_queue_lock);
275}
276
277static struct
278smb_direct_recvmsg *get_empty_recvmsg(struct smb_direct_transport *t)
279{
280	struct smb_direct_recvmsg *recvmsg = NULL;
281
282	spin_lock(&t->empty_recvmsg_queue_lock);
283	if (!list_empty(&t->empty_recvmsg_queue)) {
284		recvmsg = list_first_entry(&t->empty_recvmsg_queue,
285					   struct smb_direct_recvmsg, list);
286		list_del(&recvmsg->list);
287	}
288	spin_unlock(&t->empty_recvmsg_queue_lock);
289	return recvmsg;
290}
291
292static void put_empty_recvmsg(struct smb_direct_transport *t,
293			      struct smb_direct_recvmsg *recvmsg)
294{
295	ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
296			    recvmsg->sge.length, DMA_FROM_DEVICE);
297
298	spin_lock(&t->empty_recvmsg_queue_lock);
299	list_add_tail(&recvmsg->list, &t->empty_recvmsg_queue);
300	spin_unlock(&t->empty_recvmsg_queue_lock);
301}
302
303static void enqueue_reassembly(struct smb_direct_transport *t,
304			       struct smb_direct_recvmsg *recvmsg,
305			       int data_length)
306{
307	spin_lock(&t->reassembly_queue_lock);
308	list_add_tail(&recvmsg->list, &t->reassembly_queue);
309	t->reassembly_queue_length++;
	/*
	 * Make sure reassembly_data_length is updated after the list and
	 * reassembly_queue_length are updated. On the dequeue side
	 * reassembly_data_length is checked without a lock to determine
	 * whether reassembly_queue_length and the list are up to date.
	 */
316	virt_wmb();
317	t->reassembly_data_length += data_length;
318	spin_unlock(&t->reassembly_queue_lock);
319}
320
321static struct smb_direct_recvmsg *get_first_reassembly(struct smb_direct_transport *t)
322{
323	if (!list_empty(&t->reassembly_queue))
324		return list_first_entry(&t->reassembly_queue,
325				struct smb_direct_recvmsg, list);
326	else
327		return NULL;
328}
329
330static void smb_direct_disconnect_rdma_work(struct work_struct *work)
331{
332	struct smb_direct_transport *t =
333		container_of(work, struct smb_direct_transport,
334			     disconnect_work);
335
336	if (t->status == SMB_DIRECT_CS_CONNECTED) {
337		t->status = SMB_DIRECT_CS_DISCONNECTING;
338		rdma_disconnect(t->cm_id);
339	}
340}
341
342static void
343smb_direct_disconnect_rdma_connection(struct smb_direct_transport *t)
344{
345	if (t->status == SMB_DIRECT_CS_CONNECTED)
346		queue_work(smb_direct_wq, &t->disconnect_work);
347}
348
349static void smb_direct_send_immediate_work(struct work_struct *work)
350{
351	struct smb_direct_transport *t = container_of(work,
352			struct smb_direct_transport, send_immediate_work);
353
354	if (t->status != SMB_DIRECT_CS_CONNECTED)
355		return;
356
357	smb_direct_post_send_data(t, NULL, NULL, 0, 0);
358}
359
360static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
361{
362	struct smb_direct_transport *t;
363	struct ksmbd_conn *conn;
364
365	t = kzalloc(sizeof(*t), GFP_KERNEL);
366	if (!t)
367		return NULL;
368
369	t->cm_id = cm_id;
370	cm_id->context = t;
371
372	t->status = SMB_DIRECT_CS_NEW;
373	init_waitqueue_head(&t->wait_status);
374
375	spin_lock_init(&t->reassembly_queue_lock);
376	INIT_LIST_HEAD(&t->reassembly_queue);
377	t->reassembly_data_length = 0;
378	t->reassembly_queue_length = 0;
379	init_waitqueue_head(&t->wait_reassembly_queue);
380	init_waitqueue_head(&t->wait_send_credits);
381	init_waitqueue_head(&t->wait_rw_credits);
382
383	spin_lock_init(&t->receive_credit_lock);
384	spin_lock_init(&t->recvmsg_queue_lock);
385	INIT_LIST_HEAD(&t->recvmsg_queue);
386
387	spin_lock_init(&t->empty_recvmsg_queue_lock);
388	INIT_LIST_HEAD(&t->empty_recvmsg_queue);
389
390	init_waitqueue_head(&t->wait_send_pending);
391	atomic_set(&t->send_pending, 0);
392
393	spin_lock_init(&t->lock_new_recv_credits);
394
395	INIT_DELAYED_WORK(&t->post_recv_credits_work,
396			  smb_direct_post_recv_credits);
397	INIT_WORK(&t->send_immediate_work, smb_direct_send_immediate_work);
398	INIT_WORK(&t->disconnect_work, smb_direct_disconnect_rdma_work);
399
400	conn = ksmbd_conn_alloc();
401	if (!conn)
402		goto err;
403	conn->transport = KSMBD_TRANS(t);
404	KSMBD_TRANS(t)->conn = conn;
405	KSMBD_TRANS(t)->ops = &ksmbd_smb_direct_transport_ops;
406	return t;
407err:
408	kfree(t);
409	return NULL;
410}
411
412static void free_transport(struct smb_direct_transport *t)
413{
414	struct smb_direct_recvmsg *recvmsg;
415
416	wake_up_interruptible(&t->wait_send_credits);
417
418	ksmbd_debug(RDMA, "wait for all send posted to IB to finish\n");
419	wait_event(t->wait_send_pending,
420		   atomic_read(&t->send_pending) == 0);
421
422	cancel_work_sync(&t->disconnect_work);
423	cancel_delayed_work_sync(&t->post_recv_credits_work);
424	cancel_work_sync(&t->send_immediate_work);
425
426	if (t->qp) {
427		ib_drain_qp(t->qp);
428		ib_mr_pool_destroy(t->qp, &t->qp->rdma_mrs);
429		ib_destroy_qp(t->qp);
430	}
431
432	ksmbd_debug(RDMA, "drain the reassembly queue\n");
433	do {
434		spin_lock(&t->reassembly_queue_lock);
435		recvmsg = get_first_reassembly(t);
436		if (recvmsg) {
437			list_del(&recvmsg->list);
438			spin_unlock(&t->reassembly_queue_lock);
439			put_recvmsg(t, recvmsg);
440		} else {
441			spin_unlock(&t->reassembly_queue_lock);
442		}
443	} while (recvmsg);
444	t->reassembly_data_length = 0;
445
446	if (t->send_cq)
447		ib_free_cq(t->send_cq);
448	if (t->recv_cq)
449		ib_free_cq(t->recv_cq);
450	if (t->pd)
451		ib_dealloc_pd(t->pd);
452	if (t->cm_id)
453		rdma_destroy_id(t->cm_id);
454
455	smb_direct_destroy_pools(t);
456	ksmbd_conn_free(KSMBD_TRANS(t)->conn);
457	kfree(t);
458}
459
460static struct smb_direct_sendmsg
461*smb_direct_alloc_sendmsg(struct smb_direct_transport *t)
462{
463	struct smb_direct_sendmsg *msg;
464
465	msg = mempool_alloc(t->sendmsg_mempool, GFP_KERNEL);
466	if (!msg)
467		return ERR_PTR(-ENOMEM);
468	msg->transport = t;
469	INIT_LIST_HEAD(&msg->list);
470	msg->num_sge = 0;
471	return msg;
472}
473
474static void smb_direct_free_sendmsg(struct smb_direct_transport *t,
475				    struct smb_direct_sendmsg *msg)
476{
477	int i;
478
479	if (msg->num_sge > 0) {
480		ib_dma_unmap_single(t->cm_id->device,
481				    msg->sge[0].addr, msg->sge[0].length,
482				    DMA_TO_DEVICE);
483		for (i = 1; i < msg->num_sge; i++)
484			ib_dma_unmap_page(t->cm_id->device,
485					  msg->sge[i].addr, msg->sge[i].length,
486					  DMA_TO_DEVICE);
487	}
488	mempool_free(msg, t->sendmsg_mempool);
489}
490
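/*
 * Sanity-check a negotiate request before it is acted on: the offered
 * version range must include 0x0100, credits_requested must be positive,
 * max_receive_size must exceed 128 bytes and max_fragmented_size must exceed
 * 128 KiB. Data-transfer messages are only logged here.
 */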
491static int smb_direct_check_recvmsg(struct smb_direct_recvmsg *recvmsg)
492{
493	switch (recvmsg->type) {
494	case SMB_DIRECT_MSG_DATA_TRANSFER: {
495		struct smb_direct_data_transfer *req =
496			(struct smb_direct_data_transfer *)recvmsg->packet;
497		struct smb2_hdr *hdr = (struct smb2_hdr *)(recvmsg->packet
498				+ le32_to_cpu(req->data_offset));
499		ksmbd_debug(RDMA,
500			    "CreditGranted: %u, CreditRequested: %u, DataLength: %u, RemainingDataLength: %u, SMB: %x, Command: %u\n",
501			    le16_to_cpu(req->credits_granted),
502			    le16_to_cpu(req->credits_requested),
503			    req->data_length, req->remaining_data_length,
504			    hdr->ProtocolId, hdr->Command);
505		break;
506	}
507	case SMB_DIRECT_MSG_NEGOTIATE_REQ: {
508		struct smb_direct_negotiate_req *req =
509			(struct smb_direct_negotiate_req *)recvmsg->packet;
510		ksmbd_debug(RDMA,
511			    "MinVersion: %u, MaxVersion: %u, CreditRequested: %u, MaxSendSize: %u, MaxRecvSize: %u, MaxFragmentedSize: %u\n",
512			    le16_to_cpu(req->min_version),
513			    le16_to_cpu(req->max_version),
514			    le16_to_cpu(req->credits_requested),
515			    le32_to_cpu(req->preferred_send_size),
516			    le32_to_cpu(req->max_receive_size),
517			    le32_to_cpu(req->max_fragmented_size));
518		if (le16_to_cpu(req->min_version) > 0x0100 ||
519		    le16_to_cpu(req->max_version) < 0x0100)
520			return -EOPNOTSUPP;
521		if (le16_to_cpu(req->credits_requested) <= 0 ||
522		    le32_to_cpu(req->max_receive_size) <= 128 ||
523		    le32_to_cpu(req->max_fragmented_size) <=
524					128 * 1024)
525			return -ECONNABORTED;
526
527		break;
528	}
529	default:
530		return -EINVAL;
531	}
532	return 0;
533}
534
535static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
536{
537	struct smb_direct_recvmsg *recvmsg;
538	struct smb_direct_transport *t;
539
540	recvmsg = container_of(wc->wr_cqe, struct smb_direct_recvmsg, cqe);
541	t = recvmsg->transport;
542
543	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_RECV) {
544		if (wc->status != IB_WC_WR_FLUSH_ERR) {
545			pr_err("Recv error. status='%s (%d)' opcode=%d\n",
546			       ib_wc_status_msg(wc->status), wc->status,
547			       wc->opcode);
548			smb_direct_disconnect_rdma_connection(t);
549		}
550		put_empty_recvmsg(t, recvmsg);
551		return;
552	}
553
554	ksmbd_debug(RDMA, "Recv completed. status='%s (%d)', opcode=%d\n",
555		    ib_wc_status_msg(wc->status), wc->status,
556		    wc->opcode);
557
558	ib_dma_sync_single_for_cpu(wc->qp->device, recvmsg->sge.addr,
559				   recvmsg->sge.length, DMA_FROM_DEVICE);
560
561	switch (recvmsg->type) {
562	case SMB_DIRECT_MSG_NEGOTIATE_REQ:
563		if (wc->byte_len < sizeof(struct smb_direct_negotiate_req)) {
564			put_empty_recvmsg(t, recvmsg);
565			return;
566		}
567		t->negotiation_requested = true;
568		t->full_packet_received = true;
569		t->status = SMB_DIRECT_CS_CONNECTED;
570		enqueue_reassembly(t, recvmsg, 0);
571		wake_up_interruptible(&t->wait_status);
572		break;
573	case SMB_DIRECT_MSG_DATA_TRANSFER: {
574		struct smb_direct_data_transfer *data_transfer =
575			(struct smb_direct_data_transfer *)recvmsg->packet;
576		unsigned int data_length;
577		int avail_recvmsg_count, receive_credits;
578
579		if (wc->byte_len <
580		    offsetof(struct smb_direct_data_transfer, padding)) {
581			put_empty_recvmsg(t, recvmsg);
582			return;
583		}
584
585		data_length = le32_to_cpu(data_transfer->data_length);
586		if (data_length) {
587			if (wc->byte_len < sizeof(struct smb_direct_data_transfer) +
588			    (u64)data_length) {
589				put_empty_recvmsg(t, recvmsg);
590				return;
591			}
592
593			if (t->full_packet_received)
594				recvmsg->first_segment = true;
595
596			if (le32_to_cpu(data_transfer->remaining_data_length))
597				t->full_packet_received = false;
598			else
599				t->full_packet_received = true;
600
601			enqueue_reassembly(t, recvmsg, (int)data_length);
602			wake_up_interruptible(&t->wait_reassembly_queue);
603
604			spin_lock(&t->receive_credit_lock);
605			receive_credits = --(t->recv_credits);
606			avail_recvmsg_count = t->count_avail_recvmsg;
607			spin_unlock(&t->receive_credit_lock);
608		} else {
609			put_empty_recvmsg(t, recvmsg);
610
611			spin_lock(&t->receive_credit_lock);
612			receive_credits = --(t->recv_credits);
613			avail_recvmsg_count = ++(t->count_avail_recvmsg);
614			spin_unlock(&t->receive_credit_lock);
615		}
616
617		t->recv_credit_target =
618				le16_to_cpu(data_transfer->credits_requested);
619		atomic_add(le16_to_cpu(data_transfer->credits_granted),
620			   &t->send_credits);
621
622		if (le16_to_cpu(data_transfer->flags) &
623		    SMB_DIRECT_RESPONSE_REQUESTED)
624			queue_work(smb_direct_wq, &t->send_immediate_work);
625
626		if (atomic_read(&t->send_credits) > 0)
627			wake_up_interruptible(&t->wait_send_credits);
628
629		if (is_receive_credit_post_required(receive_credits, avail_recvmsg_count))
630			mod_delayed_work(smb_direct_wq,
631					 &t->post_recv_credits_work, 0);
632		break;
633	}
634	default:
635		break;
636	}
637}
638
639static int smb_direct_post_recv(struct smb_direct_transport *t,
640				struct smb_direct_recvmsg *recvmsg)
641{
642	struct ib_recv_wr wr;
643	int ret;
644
645	recvmsg->sge.addr = ib_dma_map_single(t->cm_id->device,
646					      recvmsg->packet, t->max_recv_size,
647					      DMA_FROM_DEVICE);
648	ret = ib_dma_mapping_error(t->cm_id->device, recvmsg->sge.addr);
649	if (ret)
650		return ret;
651	recvmsg->sge.length = t->max_recv_size;
652	recvmsg->sge.lkey = t->pd->local_dma_lkey;
653	recvmsg->cqe.done = recv_done;
654
655	wr.wr_cqe = &recvmsg->cqe;
656	wr.next = NULL;
657	wr.sg_list = &recvmsg->sge;
658	wr.num_sge = 1;
659
660	ret = ib_post_recv(t->qp, &wr, NULL);
661	if (ret) {
662		pr_err("Can't post recv: %d\n", ret);
663		ib_dma_unmap_single(t->cm_id->device,
664				    recvmsg->sge.addr, recvmsg->sge.length,
665				    DMA_FROM_DEVICE);
666		smb_direct_disconnect_rdma_connection(t);
667		return ret;
668	}
669	return ret;
670}
671
672static int smb_direct_read(struct ksmbd_transport *t, char *buf,
673			   unsigned int size, int unused)
674{
675	struct smb_direct_recvmsg *recvmsg;
676	struct smb_direct_data_transfer *data_transfer;
677	int to_copy, to_read, data_read, offset;
678	u32 data_length, remaining_data_length, data_offset;
679	int rc;
680	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
681
682again:
683	if (st->status != SMB_DIRECT_CS_CONNECTED) {
684		pr_err("disconnected\n");
685		return -ENOTCONN;
686	}
687
688	/*
689	 * No need to hold the reassembly queue lock all the time as we are
690	 * the only one reading from the front of the queue. The transport
691	 * may add more entries to the back of the queue at the same time
692	 */
693	if (st->reassembly_data_length >= size) {
694		int queue_length;
695		int queue_removed = 0;
696
		/*
		 * Make sure reassembly_data_length is read before reading
		 * reassembly_queue_length and calling get_first_reassembly.
		 * This access is lock free because we never read the tail of
		 * the queue, which is being updated in softirq context as
		 * more data is received.
		 */
704		virt_rmb();
705		queue_length = st->reassembly_queue_length;
706		data_read = 0;
707		to_read = size;
708		offset = st->first_entry_offset;
709		while (data_read < size) {
710			recvmsg = get_first_reassembly(st);
711			data_transfer = smb_direct_recvmsg_payload(recvmsg);
712			data_length = le32_to_cpu(data_transfer->data_length);
713			remaining_data_length =
714				le32_to_cpu(data_transfer->remaining_data_length);
715			data_offset = le32_to_cpu(data_transfer->data_offset);
716
			/*
			 * The upper layer expects an RFC1002 length at the
			 * beginning of the payload. Return it to indicate the
			 * total length of the packet. This minimizes the
			 * changes to the upper-layer packet processing logic
			 * and will eventually be removed when an intermediate
			 * transport layer is added.
			 */
725			if (recvmsg->first_segment && size == 4) {
726				unsigned int rfc1002_len =
727					data_length + remaining_data_length;
728				*((__be32 *)buf) = cpu_to_be32(rfc1002_len);
729				data_read = 4;
730				recvmsg->first_segment = false;
731				ksmbd_debug(RDMA,
732					    "returning rfc1002 length %d\n",
733					    rfc1002_len);
734				goto read_rfc1002_done;
735			}
736
737			to_copy = min_t(int, data_length - offset, to_read);
738			memcpy(buf + data_read, (char *)data_transfer + data_offset + offset,
739			       to_copy);
740
741			/* move on to the next buffer? */
742			if (to_copy == data_length - offset) {
743				queue_length--;
744				/*
745				 * No need to lock if we are not at the
746				 * end of the queue
747				 */
748				if (queue_length) {
749					list_del(&recvmsg->list);
750				} else {
751					spin_lock_irq(&st->reassembly_queue_lock);
752					list_del(&recvmsg->list);
753					spin_unlock_irq(&st->reassembly_queue_lock);
754				}
755				queue_removed++;
756				put_recvmsg(st, recvmsg);
757				offset = 0;
758			} else {
759				offset += to_copy;
760			}
761
762			to_read -= to_copy;
763			data_read += to_copy;
764		}
765
766		spin_lock_irq(&st->reassembly_queue_lock);
767		st->reassembly_data_length -= data_read;
768		st->reassembly_queue_length -= queue_removed;
769		spin_unlock_irq(&st->reassembly_queue_lock);
770
771		spin_lock(&st->receive_credit_lock);
772		st->count_avail_recvmsg += queue_removed;
773		if (is_receive_credit_post_required(st->recv_credits, st->count_avail_recvmsg)) {
774			spin_unlock(&st->receive_credit_lock);
775			mod_delayed_work(smb_direct_wq,
776					 &st->post_recv_credits_work, 0);
777		} else {
778			spin_unlock(&st->receive_credit_lock);
779		}
780
781		st->first_entry_offset = offset;
782		ksmbd_debug(RDMA,
783			    "returning to thread data_read=%d reassembly_data_length=%d first_entry_offset=%d\n",
784			    data_read, st->reassembly_data_length,
785			    st->first_entry_offset);
786read_rfc1002_done:
787		return data_read;
788	}
789
790	ksmbd_debug(RDMA, "wait_event on more data\n");
791	rc = wait_event_interruptible(st->wait_reassembly_queue,
792				      st->reassembly_data_length >= size ||
793				       st->status != SMB_DIRECT_CS_CONNECTED);
794	if (rc)
795		return -EINTR;
796
797	goto again;
798}
799
800static void smb_direct_post_recv_credits(struct work_struct *work)
801{
802	struct smb_direct_transport *t = container_of(work,
803		struct smb_direct_transport, post_recv_credits_work.work);
804	struct smb_direct_recvmsg *recvmsg;
805	int receive_credits, credits = 0;
806	int ret;
807	int use_free = 1;
808
809	spin_lock(&t->receive_credit_lock);
810	receive_credits = t->recv_credits;
811	spin_unlock(&t->receive_credit_lock);
812
813	if (receive_credits < t->recv_credit_target) {
814		while (true) {
815			if (use_free)
816				recvmsg = get_free_recvmsg(t);
817			else
818				recvmsg = get_empty_recvmsg(t);
819			if (!recvmsg) {
820				if (use_free) {
821					use_free = 0;
822					continue;
823				} else {
824					break;
825				}
826			}
827
828			recvmsg->type = SMB_DIRECT_MSG_DATA_TRANSFER;
829			recvmsg->first_segment = false;
830
831			ret = smb_direct_post_recv(t, recvmsg);
832			if (ret) {
833				pr_err("Can't post recv: %d\n", ret);
834				put_recvmsg(t, recvmsg);
835				break;
836			}
837			credits++;
838		}
839	}
840
841	spin_lock(&t->receive_credit_lock);
842	t->recv_credits += credits;
843	t->count_avail_recvmsg -= credits;
844	spin_unlock(&t->receive_credit_lock);
845
846	spin_lock(&t->lock_new_recv_credits);
847	t->new_recv_credits += credits;
848	spin_unlock(&t->lock_new_recv_credits);
849
850	if (credits)
851		queue_work(smb_direct_wq, &t->send_immediate_work);
852}
853
854static void send_done(struct ib_cq *cq, struct ib_wc *wc)
855{
856	struct smb_direct_sendmsg *sendmsg, *sibling;
857	struct smb_direct_transport *t;
858	struct list_head *pos, *prev, *end;
859
860	sendmsg = container_of(wc->wr_cqe, struct smb_direct_sendmsg, cqe);
861	t = sendmsg->transport;
862
863	ksmbd_debug(RDMA, "Send completed. status='%s (%d)', opcode=%d\n",
864		    ib_wc_status_msg(wc->status), wc->status,
865		    wc->opcode);
866
867	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
868		pr_err("Send error. status='%s (%d)', opcode=%d\n",
869		       ib_wc_status_msg(wc->status), wc->status,
870		       wc->opcode);
871		smb_direct_disconnect_rdma_connection(t);
872	}
873
874	if (atomic_dec_and_test(&t->send_pending))
875		wake_up(&t->wait_send_pending);
876
	/* Iterate the list of messages in reverse and free each one. The
	 * list's head is invalid and must not be treated as a sendmsg.
	 */
880	for (pos = &sendmsg->list, prev = pos->prev, end = sendmsg->list.next;
881	     prev != end; pos = prev, prev = prev->prev) {
882		sibling = container_of(pos, struct smb_direct_sendmsg, list);
883		smb_direct_free_sendmsg(t, sibling);
884	}
885
886	sibling = container_of(pos, struct smb_direct_sendmsg, list);
887	smb_direct_free_sendmsg(t, sibling);
888}
889
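/*
 * Atomically take the receive credits accumulated since the last send so they
 * can be advertised to the peer as credits_granted in the outgoing header.
 */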
890static int manage_credits_prior_sending(struct smb_direct_transport *t)
891{
892	int new_credits;
893
894	spin_lock(&t->lock_new_recv_credits);
895	new_credits = t->new_recv_credits;
896	t->new_recv_credits = 0;
897	spin_unlock(&t->lock_new_recv_credits);
898
899	return new_credits;
900}
901
902static int smb_direct_post_send(struct smb_direct_transport *t,
903				struct ib_send_wr *wr)
904{
905	int ret;
906
907	atomic_inc(&t->send_pending);
908	ret = ib_post_send(t->qp, wr, NULL);
909	if (ret) {
910		pr_err("failed to post send: %d\n", ret);
911		if (atomic_dec_and_test(&t->send_pending))
912			wake_up(&t->wait_send_pending);
913		smb_direct_disconnect_rdma_connection(t);
914	}
915	return ret;
916}
917
918static void smb_direct_send_ctx_init(struct smb_direct_transport *t,
919				     struct smb_direct_send_ctx *send_ctx,
920				     bool need_invalidate_rkey,
921				     unsigned int remote_key)
922{
923	INIT_LIST_HEAD(&send_ctx->msg_list);
924	send_ctx->wr_cnt = 0;
925	send_ctx->need_invalidate_rkey = need_invalidate_rkey;
926	send_ctx->remote_key = remote_key;
927}
928
929static int smb_direct_flush_send_list(struct smb_direct_transport *t,
930				      struct smb_direct_send_ctx *send_ctx,
931				      bool is_last)
932{
933	struct smb_direct_sendmsg *first, *last;
934	int ret;
935
936	if (list_empty(&send_ctx->msg_list))
937		return 0;
938
939	first = list_first_entry(&send_ctx->msg_list,
940				 struct smb_direct_sendmsg,
941				 list);
942	last = list_last_entry(&send_ctx->msg_list,
943			       struct smb_direct_sendmsg,
944			       list);
945
946	last->wr.send_flags = IB_SEND_SIGNALED;
947	last->wr.wr_cqe = &last->cqe;
948	if (is_last && send_ctx->need_invalidate_rkey) {
949		last->wr.opcode = IB_WR_SEND_WITH_INV;
950		last->wr.ex.invalidate_rkey = send_ctx->remote_key;
951	}
952
953	ret = smb_direct_post_send(t, &first->wr);
954	if (!ret) {
955		smb_direct_send_ctx_init(t, send_ctx,
956					 send_ctx->need_invalidate_rkey,
957					 send_ctx->remote_key);
958	} else {
959		atomic_add(send_ctx->wr_cnt, &t->send_credits);
960		wake_up(&t->wait_send_credits);
961		list_for_each_entry_safe(first, last, &send_ctx->msg_list,
962					 list) {
963			smb_direct_free_sendmsg(t, first);
964		}
965	}
966	return ret;
967}
968
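/*
 * Credit acquisition pattern used for both send and RDMA R/W credits:
 * optimistically subtract the needed amount; if that would go negative, give
 * the credits back and sleep until the peer grants more or the connection
 * drops.
 */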
969static int wait_for_credits(struct smb_direct_transport *t,
970			    wait_queue_head_t *waitq, atomic_t *total_credits,
971			    int needed)
972{
973	int ret;
974
975	do {
976		if (atomic_sub_return(needed, total_credits) >= 0)
977			return 0;
978
979		atomic_add(needed, total_credits);
980		ret = wait_event_interruptible(*waitq,
981					       atomic_read(total_credits) >= needed ||
982					       t->status != SMB_DIRECT_CS_CONNECTED);
983
984		if (t->status != SMB_DIRECT_CS_CONNECTED)
985			return -ENOTCONN;
986		else if (ret < 0)
987			return ret;
988	} while (true);
989}
990
991static int wait_for_send_credits(struct smb_direct_transport *t,
992				 struct smb_direct_send_ctx *send_ctx)
993{
994	int ret;
995
996	if (send_ctx &&
997	    (send_ctx->wr_cnt >= 16 || atomic_read(&t->send_credits) <= 1)) {
998		ret = smb_direct_flush_send_list(t, send_ctx, false);
999		if (ret)
1000			return ret;
1001	}
1002
1003	return wait_for_credits(t, &t->wait_send_credits, &t->send_credits, 1);
1004}
1005
1006static int wait_for_rw_credits(struct smb_direct_transport *t, int credits)
1007{
1008	return wait_for_credits(t, &t->wait_rw_credits, &t->rw_credits, credits);
1009}
1010
1011static int calc_rw_credits(struct smb_direct_transport *t,
1012			   char *buf, unsigned int len)
1013{
1014	return DIV_ROUND_UP(get_buf_page_count(buf, len),
1015			    t->pages_per_rw_credit);
1016}
1017
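/*
 * Build the SMB_DIRECT data-transfer header for an outgoing message and map
 * it for DMA as the first SGE. A data_offset of 24 is the size of the
 * data-transfer header itself, so the payload SGEs appended later follow it
 * directly.
 */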
1018static int smb_direct_create_header(struct smb_direct_transport *t,
1019				    int size, int remaining_data_length,
1020				    struct smb_direct_sendmsg **sendmsg_out)
1021{
1022	struct smb_direct_sendmsg *sendmsg;
1023	struct smb_direct_data_transfer *packet;
1024	int header_length;
1025	int ret;
1026
1027	sendmsg = smb_direct_alloc_sendmsg(t);
1028	if (IS_ERR(sendmsg))
1029		return PTR_ERR(sendmsg);
1030
1031	/* Fill in the packet header */
1032	packet = (struct smb_direct_data_transfer *)sendmsg->packet;
1033	packet->credits_requested = cpu_to_le16(t->send_credit_target);
1034	packet->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1035
1036	packet->flags = 0;
1037	packet->reserved = 0;
1038	if (!size)
1039		packet->data_offset = 0;
1040	else
1041		packet->data_offset = cpu_to_le32(24);
1042	packet->data_length = cpu_to_le32(size);
1043	packet->remaining_data_length = cpu_to_le32(remaining_data_length);
1044	packet->padding = 0;
1045
1046	ksmbd_debug(RDMA,
1047		    "credits_requested=%d credits_granted=%d data_offset=%d data_length=%d remaining_data_length=%d\n",
1048		    le16_to_cpu(packet->credits_requested),
1049		    le16_to_cpu(packet->credits_granted),
1050		    le32_to_cpu(packet->data_offset),
1051		    le32_to_cpu(packet->data_length),
1052		    le32_to_cpu(packet->remaining_data_length));
1053
1054	/* Map the packet to DMA */
1055	header_length = sizeof(struct smb_direct_data_transfer);
1056	/* If this is a packet without payload, don't send padding */
1057	if (!size)
1058		header_length =
1059			offsetof(struct smb_direct_data_transfer, padding);
1060
1061	sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1062						 (void *)packet,
1063						 header_length,
1064						 DMA_TO_DEVICE);
1065	ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1066	if (ret) {
1067		smb_direct_free_sendmsg(t, sendmsg);
1068		return ret;
1069	}
1070
1071	sendmsg->num_sge = 1;
1072	sendmsg->sge[0].length = header_length;
1073	sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1074
1075	*sendmsg_out = sendmsg;
1076	return 0;
1077}
1078
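/*
 * Populate a scatterlist with the pages backing a kernel buffer, handling
 * both vmalloc'ed and directly mapped memory. Returns the number of entries
 * used, or -EINVAL if the buffer needs more entries than were provided.
 */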
1079static int get_sg_list(void *buf, int size, struct scatterlist *sg_list, int nentries)
1080{
1081	bool high = is_vmalloc_addr(buf);
1082	struct page *page;
1083	int offset, len;
1084	int i = 0;
1085
1086	if (size <= 0 || nentries < get_buf_page_count(buf, size))
1087		return -EINVAL;
1088
1089	offset = offset_in_page(buf);
1090	buf -= offset;
1091	while (size > 0) {
1092		len = min_t(int, PAGE_SIZE - offset, size);
1093		if (high)
1094			page = vmalloc_to_page(buf);
1095		else
1096			page = kmap_to_page(buf);
1097
1098		if (!sg_list)
1099			return -EINVAL;
1100		sg_set_page(sg_list, page, len, offset);
1101		sg_list = sg_next(sg_list);
1102
1103		buf += PAGE_SIZE;
1104		size -= len;
1105		offset = 0;
1106		i++;
1107	}
1108	return i;
1109}
1110
1111static int get_mapped_sg_list(struct ib_device *device, void *buf, int size,
1112			      struct scatterlist *sg_list, int nentries,
1113			      enum dma_data_direction dir)
1114{
1115	int npages;
1116
1117	npages = get_sg_list(buf, size, sg_list, nentries);
1118	if (npages < 0)
1119		return -EINVAL;
1120	return ib_dma_map_sg(device, sg_list, npages, dir);
1121}
1122
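/*
 * Queue or post a send work request. With a send context the WR is only
 * chained (unsignaled) behind the previous one and actually posted later by
 * smb_direct_flush_send_list(), which signals just the final WR of the chain;
 * without a context the WR is signaled and posted immediately.
 */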
1123static int post_sendmsg(struct smb_direct_transport *t,
1124			struct smb_direct_send_ctx *send_ctx,
1125			struct smb_direct_sendmsg *msg)
1126{
1127	int i;
1128
1129	for (i = 0; i < msg->num_sge; i++)
1130		ib_dma_sync_single_for_device(t->cm_id->device,
1131					      msg->sge[i].addr, msg->sge[i].length,
1132					      DMA_TO_DEVICE);
1133
1134	msg->cqe.done = send_done;
1135	msg->wr.opcode = IB_WR_SEND;
1136	msg->wr.sg_list = &msg->sge[0];
1137	msg->wr.num_sge = msg->num_sge;
1138	msg->wr.next = NULL;
1139
1140	if (send_ctx) {
1141		msg->wr.wr_cqe = NULL;
1142		msg->wr.send_flags = 0;
1143		if (!list_empty(&send_ctx->msg_list)) {
1144			struct smb_direct_sendmsg *last;
1145
1146			last = list_last_entry(&send_ctx->msg_list,
1147					       struct smb_direct_sendmsg,
1148					       list);
1149			last->wr.next = &msg->wr;
1150		}
1151		list_add_tail(&msg->list, &send_ctx->msg_list);
1152		send_ctx->wr_cnt++;
1153		return 0;
1154	}
1155
1156	msg->wr.wr_cqe = &msg->cqe;
1157	msg->wr.send_flags = IB_SEND_SIGNALED;
1158	return smb_direct_post_send(t, &msg->wr);
1159}
1160
1161static int smb_direct_post_send_data(struct smb_direct_transport *t,
1162				     struct smb_direct_send_ctx *send_ctx,
1163				     struct kvec *iov, int niov,
1164				     int remaining_data_length)
1165{
1166	int i, j, ret;
1167	struct smb_direct_sendmsg *msg;
1168	int data_length;
1169	struct scatterlist sg[SMB_DIRECT_MAX_SEND_SGES - 1];
1170
1171	ret = wait_for_send_credits(t, send_ctx);
1172	if (ret)
1173		return ret;
1174
1175	data_length = 0;
1176	for (i = 0; i < niov; i++)
1177		data_length += iov[i].iov_len;
1178
1179	ret = smb_direct_create_header(t, data_length, remaining_data_length,
1180				       &msg);
1181	if (ret) {
1182		atomic_inc(&t->send_credits);
1183		return ret;
1184	}
1185
1186	for (i = 0; i < niov; i++) {
1187		struct ib_sge *sge;
1188		int sg_cnt;
1189
1190		sg_init_table(sg, SMB_DIRECT_MAX_SEND_SGES - 1);
1191		sg_cnt = get_mapped_sg_list(t->cm_id->device,
1192					    iov[i].iov_base, iov[i].iov_len,
1193					    sg, SMB_DIRECT_MAX_SEND_SGES - 1,
1194					    DMA_TO_DEVICE);
1195		if (sg_cnt <= 0) {
1196			pr_err("failed to map buffer\n");
1197			ret = -ENOMEM;
1198			goto err;
1199		} else if (sg_cnt + msg->num_sge > SMB_DIRECT_MAX_SEND_SGES) {
			pr_err("buffer does not fit into SGEs\n");
1201			ret = -E2BIG;
1202			ib_dma_unmap_sg(t->cm_id->device, sg, sg_cnt,
1203					DMA_TO_DEVICE);
1204			goto err;
1205		}
1206
1207		for (j = 0; j < sg_cnt; j++) {
1208			sge = &msg->sge[msg->num_sge];
1209			sge->addr = sg_dma_address(&sg[j]);
1210			sge->length = sg_dma_len(&sg[j]);
1211			sge->lkey  = t->pd->local_dma_lkey;
1212			msg->num_sge++;
1213		}
1214	}
1215
1216	ret = post_sendmsg(t, send_ctx, msg);
1217	if (ret)
1218		goto err;
1219	return 0;
1220err:
1221	smb_direct_free_sendmsg(t, msg);
1222	atomic_inc(&t->send_credits);
1223	return ret;
1224}
1225
1226static int smb_direct_writev(struct ksmbd_transport *t,
1227			     struct kvec *iov, int niovs, int buflen,
1228			     bool need_invalidate, unsigned int remote_key)
1229{
1230	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1231	int remaining_data_length;
1232	int start, i, j;
1233	int max_iov_size = st->max_send_size -
1234			sizeof(struct smb_direct_data_transfer);
1235	int ret;
1236	struct kvec vec;
1237	struct smb_direct_send_ctx send_ctx;
1238
1239	if (st->status != SMB_DIRECT_CS_CONNECTED)
1240		return -ENOTCONN;
1241
1242	//FIXME: skip RFC1002 header..
1243	buflen -= 4;
1244
1245	remaining_data_length = buflen;
1246	ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
1247
1248	smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
1249	start = i = 1;
1250	buflen = 0;
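	/*
	 * Coalesce consecutive iovecs into sends of at most max_iov_size
	 * bytes; an iovec that is larger than max_iov_size on its own is
	 * split into max_iov_size-sized chunks. iov[0] holds the already
	 * skipped RFC1002 header, so packing starts at iov[1].
	 */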
1251	while (true) {
1252		buflen += iov[i].iov_len;
1253		if (buflen > max_iov_size) {
1254			if (i > start) {
1255				remaining_data_length -=
1256					(buflen - iov[i].iov_len);
1257				ret = smb_direct_post_send_data(st, &send_ctx,
1258								&iov[start], i - start,
1259								remaining_data_length);
1260				if (ret)
1261					goto done;
1262			} else {
1263				/* iov[start] is too big, break it */
1264				int nvec  = (buflen + max_iov_size - 1) /
1265						max_iov_size;
1266
1267				for (j = 0; j < nvec; j++) {
1268					vec.iov_base =
1269						(char *)iov[start].iov_base +
1270						j * max_iov_size;
1271					vec.iov_len =
1272						min_t(int, max_iov_size,
1273						      buflen - max_iov_size * j);
1274					remaining_data_length -= vec.iov_len;
1275					ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
1276									remaining_data_length);
1277					if (ret)
1278						goto done;
1279				}
1280				i++;
1281				if (i == niovs)
1282					break;
1283			}
1284			start = i;
1285			buflen = 0;
1286		} else {
1287			i++;
1288			if (i == niovs) {
1289				/* send out all remaining vecs */
1290				remaining_data_length -= buflen;
1291				ret = smb_direct_post_send_data(st, &send_ctx,
1292								&iov[start], i - start,
1293								remaining_data_length);
1294				if (ret)
1295					goto done;
1296				break;
1297			}
1298		}
1299	}
1300
1301done:
1302	ret = smb_direct_flush_send_list(st, &send_ctx, true);
1303
	/*
	 * As an optimization, we don't wait for each individual I/O to finish
	 * before sending the next one.
	 * Send them all and wait for the pending send count to reach 0, which
	 * means all the I/Os have been posted and it is safe to return.
	 */
1310
1311	wait_event(st->wait_send_pending,
1312		   atomic_read(&st->send_pending) == 0);
1313	return ret;
1314}
1315
1316static void smb_direct_free_rdma_rw_msg(struct smb_direct_transport *t,
1317					struct smb_direct_rdma_rw_msg *msg,
1318					enum dma_data_direction dir)
1319{
1320	rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
1321			    msg->sgt.sgl, msg->sgt.nents, dir);
1322	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1323	kfree(msg);
1324}
1325
1326static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
1327			    enum dma_data_direction dir)
1328{
1329	struct smb_direct_rdma_rw_msg *msg = container_of(wc->wr_cqe,
1330							  struct smb_direct_rdma_rw_msg, cqe);
1331	struct smb_direct_transport *t = msg->t;
1332
1333	if (wc->status != IB_WC_SUCCESS) {
1334		msg->status = -EIO;
1335		pr_err("read/write error. opcode = %d, status = %s(%d)\n",
1336		       wc->opcode, ib_wc_status_msg(wc->status), wc->status);
1337		if (wc->status != IB_WC_WR_FLUSH_ERR)
1338			smb_direct_disconnect_rdma_connection(t);
1339	}
1340
1341	complete(msg->completion);
1342}
1343
1344static void read_done(struct ib_cq *cq, struct ib_wc *wc)
1345{
1346	read_write_done(cq, wc, DMA_FROM_DEVICE);
1347}
1348
1349static void write_done(struct ib_cq *cq, struct ib_wc *wc)
1350{
1351	read_write_done(cq, wc, DMA_TO_DEVICE);
1352}
1353
1354static int smb_direct_rdma_xmit(struct smb_direct_transport *t,
1355				void *buf, int buf_len,
1356				struct smb2_buffer_desc_v1 *desc,
1357				unsigned int desc_len,
1358				bool is_read)
1359{
1360	struct smb_direct_rdma_rw_msg *msg, *next_msg;
1361	int i, ret;
1362	DECLARE_COMPLETION_ONSTACK(completion);
1363	struct ib_send_wr *first_wr;
1364	LIST_HEAD(msg_list);
1365	char *desc_buf;
1366	int credits_needed;
1367	unsigned int desc_buf_len, desc_num = 0;
1368
1369	if (t->status != SMB_DIRECT_CS_CONNECTED)
1370		return -ENOTCONN;
1371
1372	if (buf_len > t->max_rdma_rw_size)
1373		return -EINVAL;
1374
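	/*
	 * The transfer runs in three steps: work out how many R/W credits the
	 * descriptor list needs, build one rdma_rw_ctx per descriptor, then
	 * chain all work requests together, post them once and wait on a
	 * single completion signalled by the last context.
	 */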
1375	/* calculate needed credits */
1376	credits_needed = 0;
1377	desc_buf = buf;
1378	for (i = 0; i < desc_len / sizeof(*desc); i++) {
1379		if (!buf_len)
1380			break;
1381
1382		desc_buf_len = le32_to_cpu(desc[i].length);
1383		if (!desc_buf_len)
1384			return -EINVAL;
1385
1386		if (desc_buf_len > buf_len) {
1387			desc_buf_len = buf_len;
1388			desc[i].length = cpu_to_le32(desc_buf_len);
1389			buf_len = 0;
1390		}
1391
1392		credits_needed += calc_rw_credits(t, desc_buf, desc_buf_len);
1393		desc_buf += desc_buf_len;
1394		buf_len -= desc_buf_len;
1395		desc_num++;
1396	}
1397
1398	ksmbd_debug(RDMA, "RDMA %s, len %#x, needed credits %#x\n",
1399		    is_read ? "read" : "write", buf_len, credits_needed);
1400
1401	ret = wait_for_rw_credits(t, credits_needed);
1402	if (ret < 0)
1403		return ret;
1404
1405	/* build rdma_rw_ctx for each descriptor */
1406	desc_buf = buf;
1407	for (i = 0; i < desc_num; i++) {
1408		msg = kzalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
1409			      sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
1410		if (!msg) {
1411			ret = -ENOMEM;
1412			goto out;
1413		}
1414
1415		desc_buf_len = le32_to_cpu(desc[i].length);
1416
1417		msg->t = t;
1418		msg->cqe.done = is_read ? read_done : write_done;
1419		msg->completion = &completion;
1420
1421		msg->sgt.sgl = &msg->sg_list[0];
1422		ret = sg_alloc_table_chained(&msg->sgt,
1423					     get_buf_page_count(desc_buf, desc_buf_len),
1424					     msg->sg_list, SG_CHUNK_SIZE);
1425		if (ret) {
1426			kfree(msg);
1427			ret = -ENOMEM;
1428			goto out;
1429		}
1430
1431		ret = get_sg_list(desc_buf, desc_buf_len,
1432				  msg->sgt.sgl, msg->sgt.orig_nents);
1433		if (ret < 0) {
1434			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1435			kfree(msg);
1436			goto out;
1437		}
1438
1439		ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
1440				       msg->sgt.sgl,
1441				       get_buf_page_count(desc_buf, desc_buf_len),
1442				       0,
1443				       le64_to_cpu(desc[i].offset),
1444				       le32_to_cpu(desc[i].token),
1445				       is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1446		if (ret < 0) {
1447			pr_err("failed to init rdma_rw_ctx: %d\n", ret);
1448			sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1449			kfree(msg);
1450			goto out;
1451		}
1452
1453		list_add_tail(&msg->list, &msg_list);
1454		desc_buf += desc_buf_len;
1455	}
1456
1457	/* concatenate work requests of rdma_rw_ctxs */
1458	first_wr = NULL;
1459	list_for_each_entry_reverse(msg, &msg_list, list) {
1460		first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
1461					   &msg->cqe, first_wr);
1462	}
1463
1464	ret = ib_post_send(t->qp, first_wr, NULL);
1465	if (ret) {
1466		pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
1467		goto out;
1468	}
1469
1470	msg = list_last_entry(&msg_list, struct smb_direct_rdma_rw_msg, list);
1471	wait_for_completion(&completion);
1472	ret = msg->status;
1473out:
1474	list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
1475		list_del(&msg->list);
1476		smb_direct_free_rdma_rw_msg(t, msg,
1477					    is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1478	}
1479	atomic_add(credits_needed, &t->rw_credits);
1480	wake_up(&t->wait_rw_credits);
1481	return ret;
1482}
1483
1484static int smb_direct_rdma_write(struct ksmbd_transport *t,
1485				 void *buf, unsigned int buflen,
1486				 struct smb2_buffer_desc_v1 *desc,
1487				 unsigned int desc_len)
1488{
1489	return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1490				    desc, desc_len, false);
1491}
1492
1493static int smb_direct_rdma_read(struct ksmbd_transport *t,
1494				void *buf, unsigned int buflen,
1495				struct smb2_buffer_desc_v1 *desc,
1496				unsigned int desc_len)
1497{
1498	return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1499				    desc, desc_len, true);
1500}
1501
1502static void smb_direct_disconnect(struct ksmbd_transport *t)
1503{
1504	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1505
1506	ksmbd_debug(RDMA, "Disconnecting cm_id=%p\n", st->cm_id);
1507
1508	smb_direct_disconnect_rdma_work(&st->disconnect_work);
1509	wait_event_interruptible(st->wait_status,
1510				 st->status == SMB_DIRECT_CS_DISCONNECTED);
1511	free_transport(st);
1512}
1513
1514static void smb_direct_shutdown(struct ksmbd_transport *t)
1515{
1516	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1517
1518	ksmbd_debug(RDMA, "smb-direct shutdown cm_id=%p\n", st->cm_id);
1519
1520	smb_direct_disconnect_rdma_work(&st->disconnect_work);
1521}
1522
1523static int smb_direct_cm_handler(struct rdma_cm_id *cm_id,
1524				 struct rdma_cm_event *event)
1525{
1526	struct smb_direct_transport *t = cm_id->context;
1527
1528	ksmbd_debug(RDMA, "RDMA CM event. cm_id=%p event=%s (%d)\n",
1529		    cm_id, rdma_event_msg(event->event), event->event);
1530
1531	switch (event->event) {
1532	case RDMA_CM_EVENT_ESTABLISHED: {
1533		t->status = SMB_DIRECT_CS_CONNECTED;
1534		wake_up_interruptible(&t->wait_status);
1535		break;
1536	}
1537	case RDMA_CM_EVENT_DEVICE_REMOVAL:
1538	case RDMA_CM_EVENT_DISCONNECTED: {
1539		ib_drain_qp(t->qp);
1540
1541		t->status = SMB_DIRECT_CS_DISCONNECTED;
1542		wake_up_interruptible(&t->wait_status);
1543		wake_up_interruptible(&t->wait_reassembly_queue);
1544		wake_up(&t->wait_send_credits);
1545		break;
1546	}
1547	case RDMA_CM_EVENT_CONNECT_ERROR: {
1548		t->status = SMB_DIRECT_CS_DISCONNECTED;
1549		wake_up_interruptible(&t->wait_status);
1550		break;
1551	}
1552	default:
1553		pr_err("Unexpected RDMA CM event. cm_id=%p, event=%s (%d)\n",
1554		       cm_id, rdma_event_msg(event->event),
1555		       event->event);
1556		break;
1557	}
1558	return 0;
1559}
1560
1561static void smb_direct_qpair_handler(struct ib_event *event, void *context)
1562{
1563	struct smb_direct_transport *t = context;
1564
1565	ksmbd_debug(RDMA, "Received QP event. cm_id=%p, event=%s (%d)\n",
1566		    t->cm_id, ib_event_msg(event->event), event->event);
1567
1568	switch (event->event) {
1569	case IB_EVENT_CQ_ERR:
1570	case IB_EVENT_QP_FATAL:
1571		smb_direct_disconnect_rdma_connection(t);
1572		break;
1573	default:
1574		break;
1575	}
1576}
1577
1578static int smb_direct_send_negotiate_response(struct smb_direct_transport *t,
1579					      int failed)
1580{
1581	struct smb_direct_sendmsg *sendmsg;
1582	struct smb_direct_negotiate_resp *resp;
1583	int ret;
1584
1585	sendmsg = smb_direct_alloc_sendmsg(t);
1586	if (IS_ERR(sendmsg))
1587		return -ENOMEM;
1588
1589	resp = (struct smb_direct_negotiate_resp *)sendmsg->packet;
1590	if (failed) {
1591		memset(resp, 0, sizeof(*resp));
1592		resp->min_version = cpu_to_le16(0x0100);
1593		resp->max_version = cpu_to_le16(0x0100);
1594		resp->status = STATUS_NOT_SUPPORTED;
1595	} else {
1596		resp->status = STATUS_SUCCESS;
1597		resp->min_version = SMB_DIRECT_VERSION_LE;
1598		resp->max_version = SMB_DIRECT_VERSION_LE;
1599		resp->negotiated_version = SMB_DIRECT_VERSION_LE;
1600		resp->reserved = 0;
1601		resp->credits_requested =
1602				cpu_to_le16(t->send_credit_target);
1603		resp->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1604		resp->max_readwrite_size = cpu_to_le32(t->max_rdma_rw_size);
1605		resp->preferred_send_size = cpu_to_le32(t->max_send_size);
1606		resp->max_receive_size = cpu_to_le32(t->max_recv_size);
1607		resp->max_fragmented_size =
1608				cpu_to_le32(t->max_fragmented_recv_size);
1609	}
1610
1611	sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1612						 (void *)resp, sizeof(*resp),
1613						 DMA_TO_DEVICE);
1614	ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1615	if (ret) {
1616		smb_direct_free_sendmsg(t, sendmsg);
1617		return ret;
1618	}
1619
1620	sendmsg->num_sge = 1;
1621	sendmsg->sge[0].length = sizeof(*resp);
1622	sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1623
1624	ret = post_sendmsg(t, NULL, sendmsg);
1625	if (ret) {
1626		smb_direct_free_sendmsg(t, sendmsg);
1627		return ret;
1628	}
1629
1630	wait_event(t->wait_send_pending,
1631		   atomic_read(&t->send_pending) == 0);
1632	return 0;
1633}
1634
1635static int smb_direct_accept_client(struct smb_direct_transport *t)
1636{
1637	struct rdma_conn_param conn_param;
1638	struct ib_port_immutable port_immutable;
1639	u32 ird_ord_hdr[2];
1640	int ret;
1641
1642	memset(&conn_param, 0, sizeof(conn_param));
1643	conn_param.initiator_depth = min_t(u8, t->cm_id->device->attrs.max_qp_rd_atom,
1644					   SMB_DIRECT_CM_INITIATOR_DEPTH);
1645	conn_param.responder_resources = 0;
1646
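	/*
	 * On iWARP the IRD/ORD values are sent to the peer as connection
	 * private data; other fabrics send no private data here.
	 */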
1647	t->cm_id->device->ops.get_port_immutable(t->cm_id->device,
1648						 t->cm_id->port_num,
1649						 &port_immutable);
1650	if (port_immutable.core_cap_flags & RDMA_CORE_PORT_IWARP) {
1651		ird_ord_hdr[0] = conn_param.responder_resources;
1652		ird_ord_hdr[1] = 1;
1653		conn_param.private_data = ird_ord_hdr;
1654		conn_param.private_data_len = sizeof(ird_ord_hdr);
1655	} else {
1656		conn_param.private_data = NULL;
1657		conn_param.private_data_len = 0;
1658	}
1659	conn_param.retry_count = SMB_DIRECT_CM_RETRY;
1660	conn_param.rnr_retry_count = SMB_DIRECT_CM_RNR_RETRY;
1661	conn_param.flow_control = 0;
1662
1663	ret = rdma_accept(t->cm_id, &conn_param);
1664	if (ret) {
1665		pr_err("error at rdma_accept: %d\n", ret);
1666		return ret;
1667	}
1668	return 0;
1669}
1670
1671static int smb_direct_prepare_negotiation(struct smb_direct_transport *t)
1672{
1673	int ret;
1674	struct smb_direct_recvmsg *recvmsg;
1675
1676	recvmsg = get_free_recvmsg(t);
1677	if (!recvmsg)
1678		return -ENOMEM;
1679	recvmsg->type = SMB_DIRECT_MSG_NEGOTIATE_REQ;
1680
1681	ret = smb_direct_post_recv(t, recvmsg);
1682	if (ret) {
1683		pr_err("Can't post recv: %d\n", ret);
1684		goto out_err;
1685	}
1686
1687	t->negotiation_requested = false;
1688	ret = smb_direct_accept_client(t);
1689	if (ret) {
1690		pr_err("Can't accept client\n");
1691		goto out_err;
1692	}
1693
1694	smb_direct_post_recv_credits(&t->post_recv_credits_work.work);
1695	return 0;
1696out_err:
1697	put_recvmsg(t, recvmsg);
1698	return ret;
1699}
1700
1701static unsigned int smb_direct_get_max_fr_pages(struct smb_direct_transport *t)
1702{
1703	return min_t(unsigned int,
1704		     t->cm_id->device->attrs.max_fast_reg_page_list_len,
1705		     256);
1706}
1707
1708static int smb_direct_init_params(struct smb_direct_transport *t,
1709				  struct ib_qp_cap *cap)
1710{
1711	struct ib_device *device = t->cm_id->device;
1712	int max_send_sges, max_rw_wrs, max_send_wrs;
1713	unsigned int max_sge_per_wr, wrs_per_credit;
1714
	/* Three extra SGEs are needed because the SMB_DIRECT header, the SMB2
	 * header and the SMB2 response may each be mapped separately.
	 */
1718	t->max_send_size = smb_direct_max_send_size;
1719	max_send_sges = DIV_ROUND_UP(t->max_send_size, PAGE_SIZE) + 3;
1720	if (max_send_sges > SMB_DIRECT_MAX_SEND_SGES) {
1721		pr_err("max_send_size %d is too large\n", t->max_send_size);
1722		return -EINVAL;
1723	}
1724
	/* Calculate the number of work requests for RDMA R/W.
	 * One R/W credit can transfer at most as many pages as can be
	 * registered with a single memory region, and at least 4 work
	 * requests per credit are needed: MR registration, the RDMA R/W
	 * itself, and local and remote MR invalidation.
	 */
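	/*
	 * For example, with 4 KiB pages, a device allowing 256 pages per
	 * fast-registration MR and, say, an 8 MiB max_read_write_size, one
	 * credit covers (256 - 1) * 4 KiB = 1020 KiB and max_rw_credits works
	 * out to 9. These numbers are illustrative only; actual values depend
	 * on the device and configuration.
	 */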
1732	t->max_rdma_rw_size = smb_direct_max_read_write_size;
1733	t->pages_per_rw_credit = smb_direct_get_max_fr_pages(t);
1734	t->max_rw_credits = DIV_ROUND_UP(t->max_rdma_rw_size,
1735					 (t->pages_per_rw_credit - 1) *
1736					 PAGE_SIZE);
1737
1738	max_sge_per_wr = min_t(unsigned int, device->attrs.max_send_sge,
1739			       device->attrs.max_sge_rd);
1740	max_sge_per_wr = max_t(unsigned int, max_sge_per_wr,
1741			       max_send_sges);
1742	wrs_per_credit = max_t(unsigned int, 4,
1743			       DIV_ROUND_UP(t->pages_per_rw_credit,
1744					    max_sge_per_wr) + 1);
1745	max_rw_wrs = t->max_rw_credits * wrs_per_credit;
1746
1747	max_send_wrs = smb_direct_send_credit_target + max_rw_wrs;
1748	if (max_send_wrs > device->attrs.max_cqe ||
1749	    max_send_wrs > device->attrs.max_qp_wr) {
1750		pr_err("consider lowering send_credit_target = %d\n",
1751		       smb_direct_send_credit_target);
1752		pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1753		       device->attrs.max_cqe, device->attrs.max_qp_wr);
1754		return -EINVAL;
1755	}
1756
1757	if (smb_direct_receive_credit_max > device->attrs.max_cqe ||
1758	    smb_direct_receive_credit_max > device->attrs.max_qp_wr) {
1759		pr_err("consider lowering receive_credit_max = %d\n",
1760		       smb_direct_receive_credit_max);
		       pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1762		       device->attrs.max_cqe, device->attrs.max_qp_wr);
1763		return -EINVAL;
1764	}
1765
1766	if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
1767		pr_err("warning: device max_recv_sge = %d too small\n",
1768		       device->attrs.max_recv_sge);
1769		return -EINVAL;
1770	}
1771
1772	t->recv_credits = 0;
1773	t->count_avail_recvmsg = 0;
1774
1775	t->recv_credit_max = smb_direct_receive_credit_max;
1776	t->recv_credit_target = 10;
1777	t->new_recv_credits = 0;
1778
1779	t->send_credit_target = smb_direct_send_credit_target;
1780	atomic_set(&t->send_credits, 0);
1781	atomic_set(&t->rw_credits, t->max_rw_credits);
1782
1783	t->max_send_size = smb_direct_max_send_size;
1784	t->max_recv_size = smb_direct_max_receive_size;
1785	t->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size;
1786
1787	cap->max_send_wr = max_send_wrs;
1788	cap->max_recv_wr = t->recv_credit_max;
1789	cap->max_send_sge = max_sge_per_wr;
1790	cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
1791	cap->max_inline_data = 0;
1792	cap->max_rdma_ctxs = t->max_rw_credits;
1793	return 0;
1794}
1795
1796static void smb_direct_destroy_pools(struct smb_direct_transport *t)
1797{
1798	struct smb_direct_recvmsg *recvmsg;
1799
1800	while ((recvmsg = get_free_recvmsg(t)))
1801		mempool_free(recvmsg, t->recvmsg_mempool);
1802	while ((recvmsg = get_empty_recvmsg(t)))
1803		mempool_free(recvmsg, t->recvmsg_mempool);
1804
1805	mempool_destroy(t->recvmsg_mempool);
1806	t->recvmsg_mempool = NULL;
1807
1808	kmem_cache_destroy(t->recvmsg_cache);
1809	t->recvmsg_cache = NULL;
1810
1811	mempool_destroy(t->sendmsg_mempool);
1812	t->sendmsg_mempool = NULL;
1813
1814	kmem_cache_destroy(t->sendmsg_cache);
1815	t->sendmsg_cache = NULL;
1816}
1817
1818static int smb_direct_create_pools(struct smb_direct_transport *t)
1819{
1820	char name[80];
1821	int i;
1822	struct smb_direct_recvmsg *recvmsg;
1823
1824	snprintf(name, sizeof(name), "smb_direct_rqst_pool_%p", t);
1825	t->sendmsg_cache = kmem_cache_create(name,
1826					     sizeof(struct smb_direct_sendmsg) +
1827					      sizeof(struct smb_direct_negotiate_resp),
1828					     0, SLAB_HWCACHE_ALIGN, NULL);
1829	if (!t->sendmsg_cache)
1830		return -ENOMEM;
1831
1832	t->sendmsg_mempool = mempool_create(t->send_credit_target,
1833					    mempool_alloc_slab, mempool_free_slab,
1834					    t->sendmsg_cache);
1835	if (!t->sendmsg_mempool)
1836		goto err;
1837
1838	snprintf(name, sizeof(name), "smb_direct_resp_%p", t);
1839	t->recvmsg_cache = kmem_cache_create(name,
1840					     sizeof(struct smb_direct_recvmsg) +
1841					      t->max_recv_size,
1842					     0, SLAB_HWCACHE_ALIGN, NULL);
1843	if (!t->recvmsg_cache)
1844		goto err;
1845
1846	t->recvmsg_mempool =
1847		mempool_create(t->recv_credit_max, mempool_alloc_slab,
1848			       mempool_free_slab, t->recvmsg_cache);
1849	if (!t->recvmsg_mempool)
1850		goto err;
1851
1852	INIT_LIST_HEAD(&t->recvmsg_queue);
1853
1854	for (i = 0; i < t->recv_credit_max; i++) {
1855		recvmsg = mempool_alloc(t->recvmsg_mempool, GFP_KERNEL);
1856		if (!recvmsg)
1857			goto err;
1858		recvmsg->transport = t;
1859		list_add(&recvmsg->list, &t->recvmsg_queue);
1860	}
1861	t->count_avail_recvmsg = t->recv_credit_max;
1862
1863	return 0;
1864err:
1865	smb_direct_destroy_pools(t);
1866	return -ENOMEM;
1867}
1868
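/*
 * Allocate the protection domain, the send and receive completion queues
 * and the RC queue pair for this connection. The send CQ is sized for
 * both send credits and RDMA read/write contexts. If a single RDMA R/W
 * credit can span more pages than the device handles per work request
 * (max_sgl_rd), an MR pool of fast-registration MRs is set up for the
 * RDMA read/write path.
 */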
1869static int smb_direct_create_qpair(struct smb_direct_transport *t,
1870				   struct ib_qp_cap *cap)
1871{
1872	int ret;
1873	struct ib_qp_init_attr qp_attr;
1874	int pages_per_rw;
1875
1876	t->pd = ib_alloc_pd(t->cm_id->device, 0);
1877	if (IS_ERR(t->pd)) {
1878		pr_err("Can't create RDMA PD\n");
1879		ret = PTR_ERR(t->pd);
1880		t->pd = NULL;
1881		return ret;
1882	}
1883
1884	t->send_cq = ib_alloc_cq(t->cm_id->device, t,
1885				 smb_direct_send_credit_target + cap->max_rdma_ctxs,
1886				 0, IB_POLL_WORKQUEUE);
1887	if (IS_ERR(t->send_cq)) {
1888		pr_err("Can't create RDMA send CQ\n");
1889		ret = PTR_ERR(t->send_cq);
1890		t->send_cq = NULL;
1891		goto err;
1892	}
1893
1894	t->recv_cq = ib_alloc_cq(t->cm_id->device, t,
1895				 t->recv_credit_max, 0, IB_POLL_WORKQUEUE);
1896	if (IS_ERR(t->recv_cq)) {
1897		pr_err("Can't create RDMA recv CQ\n");
1898		ret = PTR_ERR(t->recv_cq);
1899		t->recv_cq = NULL;
1900		goto err;
1901	}
1902
1903	memset(&qp_attr, 0, sizeof(qp_attr));
1904	qp_attr.event_handler = smb_direct_qpair_handler;
1905	qp_attr.qp_context = t;
1906	qp_attr.cap = *cap;
1907	qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
1908	qp_attr.qp_type = IB_QPT_RC;
1909	qp_attr.send_cq = t->send_cq;
1910	qp_attr.recv_cq = t->recv_cq;
1911	qp_attr.port_num = ~0;
1912
1913	ret = rdma_create_qp(t->cm_id, t->pd, &qp_attr);
1914	if (ret) {
1915		pr_err("Can't create RDMA QP: %d\n", ret);
1916		goto err;
1917	}
1918
1919	t->qp = t->cm_id->qp;
1920	t->cm_id->event_handler = smb_direct_cm_handler;
1921
1922	pages_per_rw = DIV_ROUND_UP(t->max_rdma_rw_size, PAGE_SIZE) + 1;
1923	if (pages_per_rw > t->cm_id->device->attrs.max_sgl_rd) {
1924		ret = ib_mr_pool_init(t->qp, &t->qp->rdma_mrs,
1925				      t->max_rw_credits, IB_MR_TYPE_MEM_REG,
1926				      t->pages_per_rw_credit, 0);
1927		if (ret) {
1928			pr_err("failed to init mr pool count %d pages %d\n",
1929			       t->max_rw_credits, t->pages_per_rw_credit);
1930			goto err;
1931		}
1932	}
1933
1934	return 0;
1935err:
1936	if (t->qp) {
1937		ib_destroy_qp(t->qp);
1938		t->qp = NULL;
1939	}
1940	if (t->recv_cq) {
1941		ib_destroy_cq(t->recv_cq);
1942		t->recv_cq = NULL;
1943	}
1944	if (t->send_cq) {
1945		ib_destroy_cq(t->send_cq);
1946		t->send_cq = NULL;
1947	}
1948	if (t->pd) {
1949		ib_dealloc_pd(t->pd);
1950		t->pd = NULL;
1951	}
1952	return ret;
1953}
1954
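/*
 * ->prepare() for the SMB Direct transport: wait for the client's
 * SMB_DIRECT negotiate request, clamp the local send/receive sizes
 * against the values the client advertised, and send the negotiate
 * response. The consumed receive message is removed from the reassembly
 * queue on both the success and the error path.
 */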
1955static int smb_direct_prepare(struct ksmbd_transport *t)
1956{
1957	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1958	struct smb_direct_recvmsg *recvmsg;
1959	struct smb_direct_negotiate_req *req;
1960	int ret;
1961
1962	ksmbd_debug(RDMA, "Waiting for SMB_DIRECT negotiate request\n");
1963	ret = wait_event_interruptible_timeout(st->wait_status,
1964					       st->negotiation_requested ||
1965					       st->status == SMB_DIRECT_CS_DISCONNECTED,
1966					       SMB_DIRECT_NEGOTIATE_TIMEOUT * HZ);
1967	if (ret <= 0 || st->status == SMB_DIRECT_CS_DISCONNECTED)
1968		return ret < 0 ? ret : -ETIMEDOUT;
1969
1970	recvmsg = get_first_reassembly(st);
1971	if (!recvmsg)
1972		return -ECONNABORTED;
1973
1974	ret = smb_direct_check_recvmsg(recvmsg);
1975	if (ret == -ECONNABORTED)
1976		goto out;
1977
1978	req = (struct smb_direct_negotiate_req *)recvmsg->packet;
1979	st->max_recv_size = min_t(int, st->max_recv_size,
1980				  le32_to_cpu(req->preferred_send_size));
1981	st->max_send_size = min_t(int, st->max_send_size,
1982				  le32_to_cpu(req->max_receive_size));
1983	st->max_fragmented_send_size =
1984		le32_to_cpu(req->max_fragmented_size);
1985	st->max_fragmented_recv_size =
1986		(st->recv_credit_max * st->max_recv_size) / 2;
1987
1988	ret = smb_direct_send_negotiate_response(st, ret);
1989out:
1990	spin_lock_irq(&st->reassembly_queue_lock);
1991	st->reassembly_queue_length--;
1992	list_del(&recvmsg->list);
1993	spin_unlock_irq(&st->reassembly_queue_lock);
1994	put_recvmsg(st, recvmsg);
1995
1996	return ret;
1997}
1998
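/*
 * Bring up a freshly accepted connection: choose the transport
 * parameters, create the message pools, create the queue pair and start
 * the SMB_DIRECT negotiation.
 */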
1999static int smb_direct_connect(struct smb_direct_transport *st)
2000{
2001	int ret;
2002	struct ib_qp_cap qp_cap;
2003
2004	ret = smb_direct_init_params(st, &qp_cap);
2005	if (ret) {
2006		pr_err("Can't configure RDMA parameters\n");
2007		return ret;
2008	}
2009
2010	ret = smb_direct_create_pools(st);
2011	if (ret) {
2012		pr_err("Can't init RDMA pool: %d\n", ret);
2013		return ret;
2014	}
2015
2016	ret = smb_direct_create_qpair(st, &qp_cap);
2017	if (ret) {
2018		pr_err("Can't create RDMA QP resources for client: %d\n", ret);
2019		return ret;
2020	}
2021
2022	ret = smb_direct_prepare_negotiation(st);
2023	if (ret) {
2024		pr_err("Can't negotiate: %d\n", ret);
2025		return ret;
2026	}
2027	return 0;
2028}
2029
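/*
 * Fast Registration Work Requests require the memory management
 * extensions and a non-zero fast-registration page list limit.
 */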
2030static bool rdma_frwr_is_supported(struct ib_device_attr *attrs)
2031{
2032	if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
2033		return false;
2034	if (attrs->max_fast_reg_page_list_len == 0)
2035		return false;
2036	return true;
2037}
2038
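/*
 * Handle an incoming RDMA_CM_EVENT_CONNECT_REQUEST: reject devices
 * without FRWR support, allocate a transport for the new cm_id, run the
 * connection setup and hand the connection over to a ksmbd handler
 * thread.
 */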
2039static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id)
2040{
2041	struct smb_direct_transport *t;
2042	struct task_struct *handler;
2043	int ret;
2044
2045	if (!rdma_frwr_is_supported(&new_cm_id->device->attrs)) {
2046		ksmbd_debug(RDMA,
2047			    "Fast Registration Work Requests are not supported. device capabilities=%llx\n",
2048			    new_cm_id->device->attrs.device_cap_flags);
2049		return -EPROTONOSUPPORT;
2050	}
2051
2052	t = alloc_transport(new_cm_id);
2053	if (!t)
2054		return -ENOMEM;
2055
2056	ret = smb_direct_connect(t);
2057	if (ret)
2058		goto out_err;
2059
2060	handler = kthread_run(ksmbd_conn_handler_loop,
2061			      KSMBD_TRANS(t)->conn, "ksmbd:r%u",
2062			      smb_direct_port);
2063	if (IS_ERR(handler)) {
2064		ret = PTR_ERR(handler);
2065		pr_err("Can't start thread\n");
2066		goto out_err;
2067	}
2068
2069	return 0;
2070out_err:
2071	free_transport(t);
2072	return ret;
2073}
2074
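/* rdma_cm callback for the listening cm_id; only connect requests are expected. */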
2075static int smb_direct_listen_handler(struct rdma_cm_id *cm_id,
2076				     struct rdma_cm_event *event)
2077{
2078	switch (event->event) {
2079	case RDMA_CM_EVENT_CONNECT_REQUEST: {
2080		int ret = smb_direct_handle_connect_request(cm_id);
2081
2082		if (ret) {
2083			pr_err("Can't create transport: %d\n", ret);
2084			return ret;
2085		}
2086
2087		ksmbd_debug(RDMA, "Received connection request. cm_id=%p\n",
2088			    cm_id);
2089		break;
2090	}
2091	default:
2092		pr_err("Unexpected listen event. cm_id=%p, event=%s (%d)\n",
2093		       cm_id, rdma_event_msg(event->event), event->event);
2094		break;
2095	}
2096	return 0;
2097}
2098
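/*
 * Create the listening cm_id, bind it to the IPv4 wildcard address on
 * @port and start listening with a backlog of 10.
 */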
2099static int smb_direct_listen(int port)
2100{
2101	int ret;
2102	struct rdma_cm_id *cm_id;
2103	struct sockaddr_in sin = {
2104		.sin_family		= AF_INET,
2105		.sin_addr.s_addr	= htonl(INADDR_ANY),
2106		.sin_port		= htons(port),
2107	};
2108
2109	cm_id = rdma_create_id(&init_net, smb_direct_listen_handler,
2110			       &smb_direct_listener, RDMA_PS_TCP, IB_QPT_RC);
2111	if (IS_ERR(cm_id)) {
2112		pr_err("Can't create cm id: %ld\n", PTR_ERR(cm_id));
2113		return PTR_ERR(cm_id);
2114	}
2115
2116	ret = rdma_bind_addr(cm_id, (struct sockaddr *)&sin);
2117	if (ret) {
2118		pr_err("Can't bind: %d\n", ret);
2119		goto err;
2120	}
2121
2122	smb_direct_listener.cm_id = cm_id;
2123
2124	ret = rdma_listen(cm_id, 10);
2125	if (ret) {
2126		pr_err("Can't listen: %d\n", ret);
2127		goto err;
2128	}
2129	return 0;
2130err:
2131	smb_direct_listener.cm_id = NULL;
2132	rdma_destroy_id(cm_id);
2133	return ret;
2134}
2135
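/*
 * ib_client add callback: switch to the iWARP port for non-IB devices and
 * remember every FRWR-capable device on smb_direct_device_list so that
 * ksmbd_rdma_capable_netdev() can match net_devices against it.
 */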
2136static int smb_direct_ib_client_add(struct ib_device *ib_dev)
2137{
2138	struct smb_direct_device *smb_dev;
2139
2140	/* Use port 5445 if the device type is iWARP (not InfiniBand) */
2141	if (ib_dev->node_type != RDMA_NODE_IB_CA)
2142		smb_direct_port = SMB_DIRECT_PORT_IWARP;
2143
2144	if (!rdma_frwr_is_supported(&ib_dev->attrs))
2145		return 0;
2146
2147	smb_dev = kzalloc(sizeof(*smb_dev), GFP_KERNEL);
2148	if (!smb_dev)
2149		return -ENOMEM;
2150	smb_dev->ib_dev = ib_dev;
2151
2152	write_lock(&smb_direct_device_lock);
2153	list_add(&smb_dev->list, &smb_direct_device_list);
2154	write_unlock(&smb_direct_device_lock);
2155
2156	ksmbd_debug(RDMA, "ib device added: name %s\n", ib_dev->name);
2157	return 0;
2158}
2159
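/* ib_client remove callback: drop the departing device from the list. */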
2160static void smb_direct_ib_client_remove(struct ib_device *ib_dev,
2161					void *client_data)
2162{
2163	struct smb_direct_device *smb_dev, *tmp;
2164
2165	write_lock(&smb_direct_device_lock);
2166	list_for_each_entry_safe(smb_dev, tmp, &smb_direct_device_list, list) {
2167		if (smb_dev->ib_dev == ib_dev) {
2168			list_del(&smb_dev->list);
2169			kfree(smb_dev);
2170			break;
2171		}
2172	}
2173	write_unlock(&smb_direct_device_lock);
2174}
2175
2176static struct ib_client smb_direct_ib_client = {
2177	.name	= "ksmbd_smb_direct_ib",
2178	.add	= smb_direct_ib_client_add,
2179	.remove	= smb_direct_ib_client_remove,
2180};
2181
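/**
 * ksmbd_rdma_init() - register the IB client, create the SMB Direct
 *                     workqueue and start the RDMA listener
 *
 * Return: 0 on success, otherwise a negative error code.
 */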
2182int ksmbd_rdma_init(void)
2183{
2184	int ret;
2185
2186	smb_direct_listener.cm_id = NULL;
2187
2188	ret = ib_register_client(&smb_direct_ib_client);
2189	if (ret) {
2190		pr_err("ib_register_client failed: %d\n", ret);
2191		return ret;
2192	}
2193
2194	/* When a client is running out of send credits, the server grants
2195	 * more credits by sending a packet through this workqueue. This
2196	 * avoids the situation where a client cannot send packets for lack
2197	 * of credits.
2198	 */
2199	smb_direct_wq = alloc_workqueue("ksmbd-smb_direct-wq",
2200					WQ_HIGHPRI | WQ_MEM_RECLAIM, 0);
2201	if (!smb_direct_wq)
2202		return -ENOMEM;
2203
2204	ret = smb_direct_listen(smb_direct_port);
2205	if (ret) {
2206		destroy_workqueue(smb_direct_wq);
2207		smb_direct_wq = NULL;
2208		pr_err("Can't listen: %d\n", ret);
2209		return ret;
2210	}
2211
2212	ksmbd_debug(RDMA, "init RDMA listener. cm_id=%p\n",
2213		    smb_direct_listener.cm_id);
2214	return 0;
2215}
2216
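/**
 * ksmbd_rdma_destroy() - tear down the listener, IB client and workqueue
 *                        created by ksmbd_rdma_init()
 */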
2217void ksmbd_rdma_destroy(void)
2218{
2219	if (!smb_direct_listener.cm_id)
2220		return;
2221
2222	ib_unregister_client(&smb_direct_ib_client);
2223	rdma_destroy_id(smb_direct_listener.cm_id);
2224
2225	smb_direct_listener.cm_id = NULL;
2226
2227	if (smb_direct_wq) {
2228		destroy_workqueue(smb_direct_wq);
2229		smb_direct_wq = NULL;
2230	}
2231}
2232
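/**
 * ksmbd_rdma_capable_netdev() - check whether a net_device can carry
 *                               SMB Direct traffic
 * @netdev: network device to check
 *
 * Match @netdev against the tracked FRWR-capable IB devices, either via
 * ops.get_netdev() or, for IPoIB interfaces, by looking up the GID that
 * is embedded in the netdev hardware address. If nothing matches, fall
 * back to asking the RDMA core for a device bound to @netdev and check
 * its FRWR support directly.
 */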
2233bool ksmbd_rdma_capable_netdev(struct net_device *netdev)
2234{
2235	struct smb_direct_device *smb_dev;
2236	int i;
2237	bool rdma_capable = false;
2238
2239	read_lock(&smb_direct_device_lock);
2240	list_for_each_entry(smb_dev, &smb_direct_device_list, list) {
2241		for (i = 0; i < smb_dev->ib_dev->phys_port_cnt; i++) {
2242			struct net_device *ndev;
2243
2244			if (smb_dev->ib_dev->ops.get_netdev) {
2245				ndev = smb_dev->ib_dev->ops.get_netdev(
2246					smb_dev->ib_dev, i + 1);
2247				if (!ndev)
2248					continue;
2249
2250				if (ndev == netdev) {
2251					dev_put(ndev);
2252					rdma_capable = true;
2253					goto out;
2254				}
2255				dev_put(ndev);
2256			/* if ib_dev does not implement ops.get_netdev,
2257			 * check for a matching InfiniBand GID in hw_addr
2258			 */
2259			} else if (netdev->type == ARPHRD_INFINIBAND) {
2260				struct netdev_hw_addr *ha;
2261				union ib_gid gid;
2262				u32 port_num;
2263				int ret;
2264
2265				netdev_hw_addr_list_for_each(
2266					ha, &netdev->dev_addrs) {
2267					memcpy(&gid, ha->addr + 4, sizeof(gid));
2268					ret = ib_find_gid(smb_dev->ib_dev, &gid,
2269							  &port_num, NULL);
2270					if (!ret) {
2271						rdma_capable = true;
2272						goto out;
2273					}
2274				}
2275			}
2276		}
2277	}
2278out:
2279	read_unlock(&smb_direct_device_lock);
2280
2281	if (!rdma_capable) {
2282		struct ib_device *ibdev;
2283
2284		ibdev = ib_device_get_by_netdev(netdev, RDMA_DRIVER_UNKNOWN);
2285		if (ibdev) {
2286			if (rdma_frwr_is_supported(&ibdev->attrs))
2287				rdma_capable = true;
2288			ib_device_put(ibdev);
2289		}
2290	}
2291
2292	return rdma_capable;
2293}
2294
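/* Transport operations exported to the ksmbd connection layer. */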
2295static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = {
2296	.prepare	= smb_direct_prepare,
2297	.disconnect	= smb_direct_disconnect,
2298	.shutdown	= smb_direct_shutdown,
2299	.writev		= smb_direct_writev,
2300	.read		= smb_direct_read,
2301	.rdma_read	= smb_direct_rdma_read,
2302	.rdma_write	= smb_direct_rdma_write,
2303};
2304