1#include "DNSQuery.h"
2
3#include <errno.h>
4#include <stdio.h>
5
6#include <ByteOrder.h>
7#include <FindDirectory.h>
8#include <NetAddress.h>
9#include <NetEndpoint.h>
10#include <Path.h>
11
12// #define DEBUG 1
13
14#undef PRINT
15#ifdef DEBUG
16#define PRINT(a...) printf(a)
17#else
18#define PRINT(a...)
19#endif
20
21
22static int32 gID = 1;
23
24
25BRawNetBuffer::BRawNetBuffer()
26{
27	_Init(NULL, 0);
28}
29
30
31BRawNetBuffer::BRawNetBuffer(off_t size)
32{
33	_Init(NULL, 0);
34	fBuffer.SetSize(size);
35}
36
37
38BRawNetBuffer::BRawNetBuffer(const void* buf, size_t size)
39{
40	_Init(buf, size);
41}
42
43
44status_t
45BRawNetBuffer::AppendUint16(uint16 value)
46{
47	uint16 netVal = B_HOST_TO_BENDIAN_INT16(value);
48	ssize_t sizeW = fBuffer.WriteAt(fWritePosition, &netVal, sizeof(uint16));
49	if (sizeW == B_NO_MEMORY)
50		return B_NO_MEMORY;
51	fWritePosition += sizeof(uint16);
52	return B_OK;
53}
54
55
56status_t
57BRawNetBuffer::AppendString(const char* string)
58{
59	size_t length = strlen(string) + 1;
60	ssize_t sizeW = fBuffer.WriteAt(fWritePosition, string, length);
61	if (sizeW == B_NO_MEMORY)
62		return B_NO_MEMORY;
63	fWritePosition += length;
64	return B_OK;
65}
66
67
68status_t
69BRawNetBuffer::ReadUint16(uint16& value)
70{
71	uint16 netVal;
72	ssize_t sizeW = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint16));
73	if (sizeW == 0)
74		return B_ERROR;
75	value= B_BENDIAN_TO_HOST_INT16(netVal);
76	fReadPosition += sizeof(uint16);
77	return B_OK;
78}
79
80
81status_t
82BRawNetBuffer::ReadUint32(uint32& value)
83{
84	uint32 netVal;
85	ssize_t sizeW = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint32));
86	if (sizeW == 0)
87		return B_ERROR;
88	value= B_BENDIAN_TO_HOST_INT32(netVal);
89	fReadPosition += sizeof(uint32);
90	return B_OK;
91}
92
93
94status_t
95BRawNetBuffer::ReadString(BString& string)
96{
97	string = "";
98	ssize_t bytesRead = _ReadStringAt(string, fReadPosition);
99	if (bytesRead < 0)
100		return B_ERROR;
101	fReadPosition += bytesRead;
102	return B_OK;
103}
104
105
106status_t
107BRawNetBuffer::SkipReading(off_t skip)
108{
109	if (fReadPosition + skip > (off_t)fBuffer.BufferLength())
110		return B_ERROR;
111	fReadPosition += skip;
112	return B_OK;
113}
114
115
116void
117BRawNetBuffer::_Init(const void* buf, size_t size)
118{
119	fWritePosition = 0;
120	fReadPosition = 0;
121	fBuffer.WriteAt(fWritePosition, buf, size);
122}
123
124
125ssize_t
126BRawNetBuffer::_ReadStringAt(BString& string, off_t pos)
127{
128	if (pos >= (off_t)fBuffer.BufferLength())
129		return -1;
130
131	ssize_t bytesRead = 0;
132	char* buffer = (char*)fBuffer.Buffer();
133	buffer = &buffer[pos];
134	// if the string is compressed we have to follow the links to the
135	// sub strings
136	while (pos < (off_t)fBuffer.BufferLength() && *buffer != 0) {
137		if (uint8(*buffer) == 192) {
138			// found a pointer mark
139			buffer++;
140			bytesRead++;
141			off_t subPos = uint8(*buffer);
142			_ReadStringAt(string, subPos);
143			break;
144		}
145		string.Append(buffer, 1);
146		buffer++;
147		bytesRead++;
148	}
149	bytesRead++;
150	return bytesRead;
151}
152
153
154// #pragma mark - DNSTools
155
156
157status_t
158DNSTools::GetDNSServers(BObjectList<BString>* serverList)
159{
160	// TODO: reading resolv.conf ourselves shouldn't be needed.
161	// we should have some function to retrieve the dns list
162#define	MATCH(line, name) \
163	(!strncmp(line, name, sizeof(name) - 1) && \
164	(line[sizeof(name) - 1] == ' ' || \
165	 line[sizeof(name) - 1] == '\t'))
166
167	BPath path;
168	if (find_directory(B_SYSTEM_SETTINGS_DIRECTORY, &path) != B_OK)
169		return B_ENTRY_NOT_FOUND;
170
171	path.Append("network/resolv.conf");
172
173	FILE* fp = fopen(path.Path(), "r");
174	if (fp == NULL) {
175		fprintf(stderr, "failed to open '%s' to read nameservers: %s\n",
176			path.Path(), strerror(errno));
177		return B_ENTRY_NOT_FOUND;
178	}
179
180	int nserv = 0;
181	char buf[1024];
182	char *cp; //, **pp;
183	int MAXNS = 2;
184
185	// read the config file
186	while (fgets(buf, sizeof(buf), fp) != NULL) {
187		// skip comments
188		if (*buf == ';' || *buf == '#')
189			continue;
190
191		// read nameservers to query
192		if (MATCH(buf, "nameserver") && nserv < MAXNS) {
193//			char sbuf[2];
194			cp = buf + sizeof("nameserver") - 1;
195			while (*cp == ' ' || *cp == '\t')
196				cp++;
197			cp[strcspn(cp, ";# \t\n")] = '\0';
198			if ((*cp != '\0') && (*cp != '\n')) {
199				serverList->AddItem(new BString(cp));
200				nserv++;
201			}
202		}
203		continue;
204	}
205
206	fclose(fp);
207
208	return B_OK;
209}
210
211
212BString
213DNSTools::ConvertToDNSName(const BString& string)
214{
215	BString outString = string;
216	int32 dot, lastDot, diff;
217
218	dot = string.FindFirst(".");
219	if (dot != B_ERROR) {
220		outString.Prepend((char*)&dot, 1);
221		// because we prepend a char add 1 more
222		lastDot = dot + 1;
223
224		while (true) {
225			dot = outString.FindFirst(".", lastDot + 1);
226			if (dot == B_ERROR)
227				break;
228
229			// set a counts to the dot
230			diff =  dot - 1 - lastDot;
231			outString.SetByteAt(lastDot, (char)diff);
232			lastDot = dot;
233		}
234	} else
235		lastDot = 0;
236
237	diff = outString.CountChars() - 1 - lastDot;
238	outString.SetByteAt(lastDot, (char)diff);
239
240	return outString;
241}
242
243
244BString
245DNSTools::ConvertFromDNSName(const BString& string)
246{
247	if (string.Length() == 0)
248		return string;
249
250	BString outString = string;
251	int32 dot = string[0];
252	int32 nextDot = dot;
253	outString.Remove(0, sizeof(char));
254	while (true) {
255		if (nextDot >= outString.Length())
256			break;
257		dot = outString[nextDot];
258		if (dot == 0)
259			break;
260		// set a "."
261		outString.SetByteAt(nextDot, '.');
262		nextDot+= dot + 1;
263	}
264	return outString;
265}
266
267
268// #pragma mark - DNSQuery
269// see http://tools.ietf.org/html/rfc1035 for more information about DNS
270
271
272DNSQuery::DNSQuery()
273{
274}
275
276
277DNSQuery::~DNSQuery()
278{
279}
280
281
282status_t
283DNSQuery::ReadDNSServer(in_addr* add)
284{
285	// list owns the items
286	BObjectList<BString> dnsServerList(5, true);
287	status_t status = DNSTools::GetDNSServers(&dnsServerList);
288	if (status != B_OK)
289		return status;
290
291	BString* firstDNS = dnsServerList.ItemAt(0);
292	if (firstDNS == NULL || inet_aton(firstDNS->String(), add) != 1)
293		return B_ERROR;
294
295	PRINT("dns server found: %s \n", firstDNS->String());
296	return B_OK;
297}
298
299
300status_t
301DNSQuery::GetMXRecords(const BString&  serverName,
302	BObjectList<mx_record>* mxList, bigtime_t timeout)
303{
304	// get the DNS server to ask for the mx record
305	in_addr dnsAddress;
306	if (ReadDNSServer(&dnsAddress) != B_OK)
307		return B_ERROR;
308
309	// create dns query package
310	BRawNetBuffer buffer;
311	dns_header header;
312	_SetMXHeader(&header);
313	_AppendQueryHeader(buffer, &header);
314
315	BString serverNameConv = DNSTools::ConvertToDNSName(serverName);
316	buffer.AppendString(serverNameConv);
317	buffer.AppendUint16(uint16(MX_RECORD));
318	buffer.AppendUint16(uint16(1));
319
320	// send the buffer
321	PRINT("send buffer\n");
322	BNetAddress netAddress(dnsAddress, 53);
323	BNetEndpoint netEndpoint(SOCK_DGRAM);
324	if (netEndpoint.InitCheck() != B_OK)
325		return B_ERROR;
326
327	if (netEndpoint.Connect(netAddress) != B_OK)
328		return B_ERROR;
329	PRINT("Connected\n");
330
331	int32 bytesSend = netEndpoint.Send(buffer.Data(), buffer.Size());
332	if (bytesSend == B_ERROR)
333		return B_ERROR;
334	PRINT("bytes send %i\n", int(bytesSend));
335
336	// receive buffer
337	BRawNetBuffer receiBuffer(512);
338	netEndpoint.SetTimeout(timeout);
339
340	int32 bytesRecei = netEndpoint.ReceiveFrom(receiBuffer.Data(), 512,
341		netAddress);
342	if (bytesRecei == B_ERROR)
343		return B_ERROR;
344	PRINT("bytes received %i\n", int(bytesRecei));
345
346	dns_header receiHeader;
347
348	_ReadQueryHeader(receiBuffer, &receiHeader);
349	PRINT("Package contains :");
350	PRINT("%d Questions, ", receiHeader.q_count);
351	PRINT("%d Answers, ", receiHeader.ans_count);
352	PRINT("%d Authoritative Servers, ", receiHeader.auth_count);
353	PRINT("%d Additional records\n", receiHeader.add_count);
354
355	// remove name and Question
356	BString dummyS;
357	uint16 dummy;
358	receiBuffer.ReadString(dummyS);
359	receiBuffer.ReadUint16(dummy);
360	receiBuffer.ReadUint16(dummy);
361
362	bool mxRecordFound = false;
363	for (int i = 0; i < receiHeader.ans_count; i++) {
364		resource_record_head rrHead;
365		_ReadResourceRecord(receiBuffer, &rrHead);
366		if (rrHead.type == MX_RECORD) {
367			mx_record* mxRec = new mx_record;
368			_ReadMXRecord(receiBuffer, mxRec);
369			PRINT("MX record found pri %i, name %s\n",
370				mxRec->priority, mxRec->serverName.String());
371			// Add mx record to the list
372			mxList->AddItem(mxRec);
373			mxRecordFound = true;
374		} else {
375			buffer.SkipReading(rrHead.dataLength);
376		}
377	}
378
379	if (!mxRecordFound)
380		return B_ERROR;
381
382	return B_OK;
383}
384
385
386uint16
387DNSQuery::_GetUniqueID()
388{
389	int32 nextId= atomic_add(&gID, 1);
390	// just to be sure
391	if (nextId > 65529)
392		nextId = 0;
393	return nextId;
394}
395
396
397void
398DNSQuery::_SetMXHeader(dns_header* header)
399{
400	header->id = _GetUniqueID();
401	header->qr = 0;      //This is a query
402	header->opcode = 0;  //This is a standard query
403	header->aa = 0;      //Not Authoritative
404	header->tc = 0;      //This message is not truncated
405	header->rd = 1;      //Recursion Desired
406	header->ra = 0;      //Recursion not available! hey we dont have it (lol)
407	header->z  = 0;
408	header->rcode = 0;
409	header->q_count = 1;   //we have only 1 question
410	header->ans_count  = 0;
411	header->auth_count = 0;
412	header->add_count  = 0;
413}
414
415
416void
417DNSQuery::_AppendQueryHeader(BRawNetBuffer& buffer, const dns_header* header)
418{
419	buffer.AppendUint16(header->id);
420	uint16 data = 0;
421	data |= header->rcode;
422	data |= header->z << 4;
423	data |= header->ra << 7;
424	data |= header->rd << 8;
425	data |= header->tc << 9;
426	data |= header->aa << 10;
427	data |= header->opcode << 11;
428	data |= header->qr << 15;
429	buffer.AppendUint16(data);
430	buffer.AppendUint16(header->q_count);
431	buffer.AppendUint16(header->ans_count);
432	buffer.AppendUint16(header->auth_count);
433	buffer.AppendUint16(header->add_count);
434}
435
436
437void
438DNSQuery::_ReadQueryHeader(BRawNetBuffer& buffer, dns_header* header)
439{
440	buffer.ReadUint16(header->id);
441	uint16 data = 0;
442	buffer.ReadUint16(data);
443	header->rcode = data & 0x0F;
444	header->z = (data >> 4) & 0x07;
445	header->ra = (data >> 7) & 0x01;
446	header->rd = (data >> 8) & 0x01;
447	header->tc = (data >> 9) & 0x01;
448	header->aa = (data >> 10) & 0x01;
449	header->opcode = (data >> 11) & 0x0F;
450	header->qr = (data >> 15) & 0x01;
451	buffer.ReadUint16(header->q_count);
452	buffer.ReadUint16(header->ans_count);
453	buffer.ReadUint16(header->auth_count);
454	buffer.ReadUint16(header->add_count);
455}
456
457
458void
459DNSQuery::_ReadMXRecord(BRawNetBuffer& buffer, mx_record* mxRecord)
460{
461	buffer.ReadUint16(mxRecord->priority);
462	buffer.ReadString(mxRecord->serverName);
463	mxRecord->serverName = DNSTools::ConvertFromDNSName(mxRecord->serverName);
464}
465
466
467void
468DNSQuery::_ReadResourceRecord(BRawNetBuffer& buffer,
469	resource_record_head *rrHead)
470{
471	buffer.ReadString(rrHead->name);
472	buffer.ReadUint16(rrHead->type);
473	buffer.ReadUint16(rrHead->dataClass);
474	buffer.ReadUint32(rrHead->ttl);
475	buffer.ReadUint16(rrHead->dataLength);
476}
477