1// RequestChannel.cpp
2
3#include "RequestChannel.h"
4
5#include <new>
6#include <typeinfo>
7
8#include <stdlib.h>
9#include <string.h>
10
11#include <AutoDeleter.h>
12#include <ByteOrder.h>
13
14#include "Channel.h"
15#include "Compatibility.h"
16#include "DebugSupport.h"
17#include "Request.h"
18#include "RequestDumper.h"
19#include "RequestFactory.h"
20#include "RequestFlattener.h"
21#include "RequestUnflattener.h"
22
23static const int32 kMaxSaneRequestSize	= 128 * 1024;	// 128 KB
24static const int32 kDefaultBufferSize	= 4096;			// 4 KB
25
26// ChannelWriter
27class RequestChannel::ChannelWriter : public Writer {
28public:
29	ChannelWriter(Channel* channel, void* buffer, int32 bufferSize)
30		: fChannel(channel),
31		  fBuffer(buffer),
32		  fBufferSize(bufferSize),
33		  fBytesWritten(0)
34	{
35	}
36
37	virtual status_t Write(const void* buffer, int32 size)
38	{
39		status_t error = B_OK;
40		// if the data don't fit into the buffer anymore, flush the buffer
41		if (fBytesWritten + size > fBufferSize) {
42			error = Flush();
43			if (error != B_OK)
44				return error;
45		}
46		// if the data don't even fit into an empty buffer, just send it,
47		// otherwise append it to the buffer
48		if (size > fBufferSize) {
49			error = fChannel->Send(buffer, size);
50			if (error != B_OK)
51				return error;
52		} else {
53			memcpy((uint8*)fBuffer + fBytesWritten, buffer, size);
54			fBytesWritten += size;
55		}
56		return error;
57	}
58
59	status_t Flush()
60	{
61		if (fBytesWritten == 0)
62			return B_OK;
63		status_t error = fChannel->Send(fBuffer, fBytesWritten);
64		if (error != B_OK)
65			return error;
66		fBytesWritten = 0;
67		return B_OK;
68	}
69
70private:
71	Channel*	fChannel;
72	void*		fBuffer;
73	int32		fBufferSize;
74	int32		fBytesWritten;
75};
76
77
78// MemoryReader
79class RequestChannel::MemoryReader : public Reader {
80public:
81	MemoryReader(void* buffer, int32 bufferSize)
82		: Reader(),
83		  fBuffer(buffer),
84		  fBufferSize(bufferSize),
85		  fBytesRead(0)
86	{
87	}
88
89	virtual status_t Read(void* buffer, int32 size)
90	{
91		// check parameters
92		if (!buffer || size < 0)
93			return B_BAD_VALUE;
94		if (size == 0)
95			return B_OK;
96		// get pointer into data buffer
97		void* localBuffer;
98		bool mustFree;
99		status_t error = Read(size, &localBuffer, &mustFree);
100		if (error != B_OK)
101			return error;
102		// copy data into supplied buffer
103		memcpy(buffer, localBuffer, size);
104		return B_OK;
105	}
106
107	virtual status_t Read(int32 size, void** buffer, bool* mustFree)
108	{
109		// check parameters
110		if (size < 0 || !buffer || !mustFree)
111			return B_BAD_VALUE;
112		if (fBytesRead + size > fBufferSize)
113			return B_BAD_VALUE;
114		// get the data pointer
115		*buffer = (uint8*)fBuffer + fBytesRead;
116		*mustFree = false;
117		fBytesRead += size;
118		return B_OK;
119	}
120
121	bool AllBytesRead() const
122	{
123		return (fBytesRead == fBufferSize);
124	}
125
126private:
127	void*	fBuffer;
128	int32	fBufferSize;
129	int32	fBytesRead;
130};
131
132
133// RequestHeader
134struct RequestChannel::RequestHeader {
135	uint32	type;
136	int32	size;
137};
138
139
140// constructor
141RequestChannel::RequestChannel(Channel* channel)
142	: fChannel(channel),
143	  fBuffer(NULL),
144	  fBufferSize(0)
145{
146	// allocate the send buffer
147	fBuffer = malloc(kDefaultBufferSize);
148	if (fBuffer)
149		fBufferSize = kDefaultBufferSize;
150}
151
152// destructor
153RequestChannel::~RequestChannel()
154{
155	free(fBuffer);
156}
157
158// SendRequest
159status_t
160RequestChannel::SendRequest(Request* request)
161{
162	if (!request)
163		RETURN_ERROR(B_BAD_VALUE);
164	PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
165
166	// get request size
167	int32 size;
168	status_t error = _GetRequestSize(request, &size);
169	if (error != B_OK)
170		RETURN_ERROR(error);
171	if (size < 0 || size > kMaxSaneRequestSize) {
172		ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: "
173			"%" B_PRId32 "\n", size);
174		RETURN_ERROR(B_BAD_DATA);
175	}
176
177	// write the request header
178	RequestHeader header;
179	header.type = B_HOST_TO_BENDIAN_INT32(request->GetType());
180	header.size = B_HOST_TO_BENDIAN_INT32(size);
181	ChannelWriter writer(fChannel, fBuffer, fBufferSize);
182	error = writer.Write(&header, sizeof(RequestHeader));
183	if (error != B_OK)
184		RETURN_ERROR(error);
185
186	// now write the request in earnest
187	RequestFlattener flattener(&writer);
188	request->Flatten(&flattener);
189	error = flattener.GetStatus();
190	if (error != B_OK)
191		RETURN_ERROR(error);
192	error = writer.Flush();
193	RETURN_ERROR(error);
194}
195
196// ReceiveRequest
197status_t
198RequestChannel::ReceiveRequest(Request** _request)
199{
200	if (!_request)
201		RETURN_ERROR(B_BAD_VALUE);
202
203	// get the request header
204	RequestHeader header;
205	status_t error = fChannel->Receive(&header, sizeof(RequestHeader));
206	if (error != B_OK)
207		RETURN_ERROR(error);
208	header.type = B_HOST_TO_BENDIAN_INT32(header.type);
209	header.size = B_HOST_TO_BENDIAN_INT32(header.size);
210	if (header.size < 0 || header.size > kMaxSaneRequestSize) {
211		ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: "
212			"%" B_PRId32 "\n", header.size);
213		RETURN_ERROR(B_BAD_DATA);
214	}
215
216	// create the request
217	Request* request;
218	error = RequestFactory::CreateRequest(header.type, &request);
219	if (error != B_OK)
220		RETURN_ERROR(error);
221	ObjectDeleter<Request> requestDeleter(request);
222
223	// allocate a buffer for the data and read them
224	if (header.size > 0) {
225		RequestBuffer* requestBuffer = RequestBuffer::Create(header.size);
226		if (!requestBuffer)
227			RETURN_ERROR(B_NO_MEMORY);
228		request->AttachBuffer(requestBuffer);
229
230		// receive the data
231		error = fChannel->Receive(requestBuffer->GetData(), header.size);
232		if (error != B_OK)
233			RETURN_ERROR(error);
234
235		// unflatten the request
236		MemoryReader reader(requestBuffer->GetData(), header.size);
237		RequestUnflattener unflattener(&reader);
238		request->Unflatten(&unflattener);
239		error = unflattener.GetStatus();
240		if (error != B_OK)
241			RETURN_ERROR(error);
242		if (!reader.AllBytesRead())
243			RETURN_ERROR(B_BAD_DATA);
244	}
245
246	requestDeleter.Detach();
247	*_request = request;
248	PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
249	return B_OK;
250}
251
252// _GetRequestSize
253status_t
254RequestChannel::_GetRequestSize(Request* request, int32* size)
255{
256	DummyWriter dummyWriter;
257	RequestFlattener flattener(&dummyWriter);
258	request->ShowAround(&flattener);
259	status_t error = flattener.GetStatus();
260	if (error != B_OK)
261		return error;
262	*size = flattener.GetBytesWritten();
263	return error;
264}
265
266