1/*
2   Unix SMB/CIFS implementation.
3   Packet handling
4   Copyright (C) Volker Lendecke 2007
5
6   This program is free software; you can redistribute it and/or modify
7   it under the terms of the GNU General Public License as published by
8   the Free Software Foundation; either version 3 of the License, or
9   (at your option) any later version.
10
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14   GNU General Public License for more details.
15
16   You should have received a copy of the GNU General Public License
17   along with this program.  If not, see <http://www.gnu.org/licenses/>.
18*/
19
20#include "includes.h"
21
22struct packet_context {
23	int fd;
24	DATA_BLOB in, out;
25};
26
27/*
28 * Close the underlying fd
29 */
30static int packet_context_destructor(struct packet_context *ctx)
31{
32	return close(ctx->fd);
33}
34
35/*
36 * Initialize a packet context. The fd is given to the packet context, meaning
37 * that it is automatically closed when the packet context is freed.
38 */
39struct packet_context *packet_init(TALLOC_CTX *mem_ctx, int fd)
40{
41	struct packet_context *result;
42
43	if (!(result = TALLOC_ZERO_P(mem_ctx, struct packet_context))) {
44		return NULL;
45	}
46
47	result->fd = fd;
48	talloc_set_destructor(result, packet_context_destructor);
49	return result;
50}
51
52/*
53 * Pull data from the fd
54 */
55NTSTATUS packet_fd_read(struct packet_context *ctx)
56{
57	int res, available;
58	size_t new_size;
59	uint8 *in;
60
61	res = ioctl(ctx->fd, FIONREAD, &available);
62
63	if (res == -1) {
64		DEBUG(10, ("ioctl(FIONREAD) failed: %s\n", strerror(errno)));
65		return map_nt_error_from_unix(errno);
66	}
67
68	SMB_ASSERT(available >= 0);
69
70	if (available == 0) {
71		return NT_STATUS_END_OF_FILE;
72	}
73
74	new_size = ctx->in.length + available;
75
76	if (new_size < ctx->in.length) {
77		DEBUG(0, ("integer wrap\n"));
78		return NT_STATUS_NO_MEMORY;
79	}
80
81	if (!(in = TALLOC_REALLOC_ARRAY(ctx, ctx->in.data, uint8, new_size))) {
82		DEBUG(10, ("talloc failed\n"));
83		return NT_STATUS_NO_MEMORY;
84	}
85
86	ctx->in.data = in;
87
88	res = recv(ctx->fd, in + ctx->in.length, available, 0);
89
90	if (res < 0) {
91		DEBUG(10, ("recv failed: %s\n", strerror(errno)));
92		return map_nt_error_from_unix(errno);
93	}
94
95	if (res == 0) {
96		return NT_STATUS_END_OF_FILE;
97	}
98
99	ctx->in.length += res;
100
101	return NT_STATUS_OK;
102}
103
104NTSTATUS packet_fd_read_sync(struct packet_context *ctx,
105			     struct timeval *timeout)
106{
107	int res;
108	fd_set r_fds;
109
110	if (ctx->fd < 0 || ctx->fd >= FD_SETSIZE) {
111		errno = EBADF;
112		return map_nt_error_from_unix(errno);
113	}
114
115	FD_ZERO(&r_fds);
116	FD_SET(ctx->fd, &r_fds);
117
118	res = sys_select(ctx->fd+1, &r_fds, NULL, NULL, timeout);
119
120	if (res == 0) {
121		DEBUG(10, ("select timed out\n"));
122		return NT_STATUS_IO_TIMEOUT;
123	}
124
125	if (res == -1) {
126		DEBUG(10, ("select returned %s\n", strerror(errno)));
127		return map_nt_error_from_unix(errno);
128	}
129
130	return packet_fd_read(ctx);
131}
132
133bool packet_handler(struct packet_context *ctx,
134		    bool (*full_req)(const uint8_t *buf,
135				     size_t available,
136				     size_t *length,
137				     void *priv),
138		    NTSTATUS (*callback)(uint8_t *buf, size_t length,
139					 void *priv),
140		    void *priv, NTSTATUS *status)
141{
142	size_t length;
143	uint8_t *buf;
144
145	if (!full_req(ctx->in.data, ctx->in.length, &length, priv)) {
146		return False;
147	}
148
149	if (length > ctx->in.length) {
150		*status = NT_STATUS_INTERNAL_ERROR;
151		return true;
152	}
153
154	if (length == ctx->in.length) {
155		buf = ctx->in.data;
156		ctx->in.data = NULL;
157		ctx->in.length = 0;
158	} else {
159		buf = (uint8_t *)TALLOC_MEMDUP(ctx, ctx->in.data, length);
160		if (buf == NULL) {
161			*status = NT_STATUS_NO_MEMORY;
162			return true;
163		}
164
165		memmove(ctx->in.data, ctx->in.data + length,
166			ctx->in.length - length);
167		ctx->in.length -= length;
168	}
169
170	*status = callback(buf, length, priv);
171	return True;
172}
173
174/*
175 * How many bytes of outgoing data do we have pending?
176 */
177size_t packet_outgoing_bytes(struct packet_context *ctx)
178{
179	return ctx->out.length;
180}
181
182/*
183 * Push data to the fd
184 */
185NTSTATUS packet_fd_write(struct packet_context *ctx)
186{
187	ssize_t sent;
188
189	sent = send(ctx->fd, ctx->out.data, ctx->out.length, 0);
190
191	if (sent == -1) {
192		DEBUG(0, ("send failed: %s\n", strerror(errno)));
193		return map_nt_error_from_unix(errno);
194	}
195
196	memmove(ctx->out.data, ctx->out.data + sent,
197		ctx->out.length - sent);
198	ctx->out.length -= sent;
199
200	return NT_STATUS_OK;
201}
202
203/*
204 * Sync flush all outgoing bytes
205 */
206NTSTATUS packet_flush(struct packet_context *ctx)
207{
208	while (ctx->out.length != 0) {
209		NTSTATUS status = packet_fd_write(ctx);
210		if (!NT_STATUS_IS_OK(status)) {
211			return status;
212		}
213	}
214	return NT_STATUS_OK;
215}
216
217/*
218 * Send a list of DATA_BLOBs
219 *
220 * Example:  packet_send(ctx, 2, data_blob_const(&size, sizeof(size)),
221 *			 data_blob_const(buf, size));
222 */
223NTSTATUS packet_send(struct packet_context *ctx, int num_blobs, ...)
224{
225	va_list ap;
226	int i;
227	size_t len;
228	uint8 *out;
229
230	len = ctx->out.length;
231
232	va_start(ap, num_blobs);
233	for (i=0; i<num_blobs; i++) {
234		size_t tmp;
235		DATA_BLOB blob = va_arg(ap, DATA_BLOB);
236
237		tmp = len + blob.length;
238		if (tmp < len) {
239			DEBUG(0, ("integer overflow\n"));
240			va_end(ap);
241			return NT_STATUS_NO_MEMORY;
242		}
243		len = tmp;
244	}
245	va_end(ap);
246
247	if (len == 0) {
248		return NT_STATUS_OK;
249	}
250
251	if (!(out = TALLOC_REALLOC_ARRAY(ctx, ctx->out.data, uint8, len))) {
252		DEBUG(0, ("talloc failed\n"));
253		return NT_STATUS_NO_MEMORY;
254	}
255
256	ctx->out.data = out;
257
258	va_start(ap, num_blobs);
259	for (i=0; i<num_blobs; i++) {
260		DATA_BLOB blob = va_arg(ap, DATA_BLOB);
261
262		memcpy(ctx->out.data+ctx->out.length, blob.data, blob.length);
263		ctx->out.length += blob.length;
264	}
265	va_end(ap);
266
267	SMB_ASSERT(ctx->out.length == len);
268	return NT_STATUS_OK;
269}
270
271/*
272 * Get the packet context's file descriptor
273 */
274int packet_get_fd(struct packet_context *ctx)
275{
276	return ctx->fd;
277}
278
279