1/*	$OpenBSD: radius.c,v 1.4 2023/07/08 08:53:26 yasuoka Exp $ */
2
3/*-
4 * Copyright (c) 2009 Internet Initiative Japan Inc.
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE"AUTHOR" AND CONTRIBUTORS AS IS'' AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29#include <sys/socket.h>
30#include <sys/uio.h>
31#include <arpa/inet.h>
32
33#include <stdint.h>
34#include <stdio.h>
35#include <stdlib.h>
36#include <string.h>
37
38#include <openssl/md5.h>
39
40#include "radius.h"
41
42#include "radius_local.h"
43
/* Source of Identifier octets for newly created/renumbered packets (starts at 0, wraps at 256). */
static uint8_t radius_id_counter;
45
46static int
47radius_check_packet_data(const RADIUS_PACKET_DATA * pdata, size_t length)
48{
49	const RADIUS_ATTRIBUTE	*attr;
50	const RADIUS_ATTRIBUTE	*end;
51
52	if (length < sizeof(RADIUS_PACKET_DATA))
53		return (-1);
54	if (length > 0xffff)
55		return (-1);
56	if (length != (size_t) ntohs(pdata->length))
57		return (-1);
58
59	attr = ATTRS_BEGIN(pdata);
60	end = ATTRS_END(pdata);
61	for (; attr < end; ATTRS_ADVANCE(attr)) {
62		if (attr->length < 2)
63			return (-1);
64		if (attr->type == RADIUS_TYPE_VENDOR_SPECIFIC) {
65			if (attr->length < 8)
66				return (-1);
67			if ((attr->vendor & htonl(0xff000000U)) != 0)
68				return (-1);
69			if (attr->length != attr->vlength + 6)
70				return (-1);
71		}
72	}
73
74	if (attr != end)
75		return (-1);
76
77	return (0);
78}
79
80int
81radius_ensure_add_capacity(RADIUS_PACKET * packet, size_t capacity)
82{
83	size_t	 newsize;
84	void	*newptr;
85
86	/*
87	 * The maximum size is 64KB.
88	 * We use little bit smaller value for our safety(?).
89	 */
90	if (ntohs(packet->pdata->length) + capacity > 0xfe00)
91		return (-1);
92
93	if (ntohs(packet->pdata->length) + capacity > packet->capacity) {
94		newsize = ntohs(packet->pdata->length) + capacity +
95		    RADIUS_PACKET_CAPACITY_INCREMENT;
96		newptr = realloc(packet->pdata, newsize);
97		if (newptr == NULL)
98			return (-1);
99		packet->capacity = newsize;
100		packet->pdata = (RADIUS_PACKET_DATA *)newptr;
101	}
102	return (0);
103}
104
105RADIUS_PACKET *
106radius_new_request_packet(uint8_t code)
107{
108	RADIUS_PACKET	*packet;
109
110	packet = malloc(sizeof(RADIUS_PACKET));
111	if (packet == NULL)
112		return (NULL);
113	packet->pdata = malloc(RADIUS_PACKET_CAPACITY_INITIAL);
114	if (packet->pdata == NULL) {
115		free(packet);
116		return (NULL);
117	}
118	packet->capacity = RADIUS_PACKET_CAPACITY_INITIAL;
119	packet->request = NULL;
120	packet->pdata->code = code;
121	packet->pdata->id = radius_id_counter++;
122	packet->pdata->length = htons(sizeof(RADIUS_PACKET_DATA));
123	arc4random_buf(packet->pdata->authenticator,
124	    sizeof(packet->pdata->authenticator));
125
126	return (packet);
127}
128
129RADIUS_PACKET *
130radius_new_response_packet(uint8_t code, const RADIUS_PACKET * request)
131{
132	RADIUS_PACKET	*packet;
133
134	packet = radius_new_request_packet(code);
135	if (packet == NULL)
136		return (NULL);
137	packet->request = request;
138	packet->pdata->id = request->pdata->id;
139
140	return (packet);
141}
142
143RADIUS_PACKET *
144radius_convert_packet(const void *pdata, size_t length)
145{
146	RADIUS_PACKET *packet;
147
148	if (radius_check_packet_data((const RADIUS_PACKET_DATA *)pdata,
149	    length) != 0)
150		return (NULL);
151	packet = malloc(sizeof(RADIUS_PACKET));
152	if (packet == NULL)
153		return (NULL);
154	packet->pdata = malloc(length);
155	packet->capacity = length;
156	packet->request = NULL;
157	if (packet->pdata == NULL) {
158		free(packet);
159		return (NULL);
160	}
161	memcpy(packet->pdata, pdata, length);
162
163	return (packet);
164}
165
166int
167radius_delete_packet(RADIUS_PACKET * packet)
168{
169	free(packet->pdata);
170	free(packet);
171	return (0);
172}
173
174uint8_t
175radius_get_code(const RADIUS_PACKET * packet)
176{
177	return (packet->pdata->code);
178}
179
180uint8_t
181radius_get_id(const RADIUS_PACKET * packet)
182{
183	return (packet->pdata->id);
184}
185
186void
187radius_update_id(RADIUS_PACKET * packet)
188{
189	packet->pdata->id = radius_id_counter++;
190}
191
192void
193radius_set_id(RADIUS_PACKET * packet, uint8_t id)
194{
195	packet->pdata->id = id;
196}
197
198void
199radius_get_authenticator(const RADIUS_PACKET * packet, void *authenticator)
200{
201	memcpy(authenticator, packet->pdata->authenticator, 16);
202}
203
204uint8_t *
205radius_get_authenticator_retval(const RADIUS_PACKET * packet)
206{
207	return (packet->pdata->authenticator);
208}
209
210uint8_t *
211radius_get_request_authenticator_retval(const RADIUS_PACKET * packet)
212{
213	if (packet->request == NULL)
214		return (packet->pdata->authenticator);
215	else
216		return (packet->request->pdata->authenticator);
217}
218
219void
220radius_set_request_packet(RADIUS_PACKET * packet,
221    const RADIUS_PACKET * request)
222{
223	packet->request = request;
224}
225
226const RADIUS_PACKET *
227radius_get_request_packet(const RADIUS_PACKET * packet)
228{
229	return (packet->request);
230}
231
232static void
233radius_calc_authenticator(uint8_t * authenticator_dst,
234    const RADIUS_PACKET * packet, const uint8_t * authenticator_src,
235    const char *secret)
236{
237	MD5_CTX	 ctx;
238
239	MD5_Init(&ctx);
240	MD5_Update(&ctx, (unsigned char *)packet->pdata, 4);
241	MD5_Update(&ctx, (unsigned char *)authenticator_src, 16);
242	MD5_Update(&ctx,
243	    (unsigned char *)packet->pdata->attributes,
244	    radius_get_length(packet) - 20);
245	MD5_Update(&ctx, (unsigned char *)secret, strlen(secret));
246	MD5_Final((unsigned char *)authenticator_dst, &ctx);
247}
248
249static void
250radius_calc_response_authenticator(uint8_t * authenticator_dst,
251    const RADIUS_PACKET * packet, const char *secret)
252{
253	radius_calc_authenticator(authenticator_dst,
254	    packet, packet->request->pdata->authenticator, secret);
255}
256
257int
258radius_check_response_authenticator(const RADIUS_PACKET * packet,
259    const char *secret)
260{
261	uint8_t authenticator[16];
262
263	radius_calc_response_authenticator(authenticator, packet, secret);
264	return (memcmp(authenticator, packet->pdata->authenticator, 16));
265}
266
267void
268radius_set_response_authenticator(RADIUS_PACKET * packet,
269    const char *secret)
270{
271	radius_calc_response_authenticator(packet->pdata->authenticator,
272	    packet, secret);
273}
274
275static void
276radius_calc_accounting_request_authenticator(uint8_t * authenticator_dst,
277    const RADIUS_PACKET * packet, const char *secret)
278{
279	uint8_t	 zero[16];
280
281	memset(zero, 0, sizeof(zero));
282	radius_calc_authenticator(authenticator_dst,
283	    packet, zero, secret);
284}
285
286void
287radius_set_accounting_request_authenticator(RADIUS_PACKET * packet,
288    const char *secret)
289{
290	radius_calc_accounting_request_authenticator(
291	    packet->pdata->authenticator, packet, secret);
292}
293
294int
295radius_check_accounting_request_authenticator(const RADIUS_PACKET * packet,
296    const char *secret)
297{
298	uint8_t authenticator[16];
299
300	radius_calc_accounting_request_authenticator(authenticator, packet,
301	    secret);
302	return (memcmp(authenticator, packet->pdata->authenticator, 16));
303}
304
305
306uint16_t
307radius_get_length(const RADIUS_PACKET * packet)
308{
309	return (ntohs(packet->pdata->length));
310}
311
312
313const void *
314radius_get_data(const RADIUS_PACKET * packet)
315{
316	return (packet->pdata);
317}
318
319RADIUS_PACKET *
320radius_recvfrom(int s, int flags, struct sockaddr * sa, socklen_t * slen)
321{
322	char	 buf[0x10000];
323	ssize_t	 n;
324
325	n = recvfrom(s, buf, sizeof(buf), flags, sa, slen);
326	if (n <= 0)
327		return (NULL);
328
329	return (radius_convert_packet(buf, (size_t) n));
330}
331
332int
333radius_sendto(int s, const RADIUS_PACKET * packet,
334    int flags, const struct sockaddr * sa, socklen_t slen)
335{
336	ssize_t	 n;
337
338	n = sendto(s, packet->pdata, radius_get_length(packet), flags, sa,
339	    slen);
340	if (n != radius_get_length(packet))
341		return (-1);
342
343	return (0);
344}
345
346RADIUS_PACKET *
347radius_recv(int s, int flags)
348{
349	char	 buf[0x10000];
350	ssize_t	 n;
351
352	n = recv(s, buf, sizeof(buf), flags);
353	if (n <= 0)
354		return (NULL);
355
356	return (radius_convert_packet(buf, (size_t) n));
357}
358
359int
360radius_send(int s, const RADIUS_PACKET * packet, int flags)
361{
362	ssize_t	 n;
363
364	n = send(s, packet->pdata, radius_get_length(packet), flags);
365	if (n != radius_get_length(packet))
366		return (-1);
367
368	return (0);
369}
370
371RADIUS_PACKET *
372radius_recvmsg(int s, struct msghdr * msg, int flags)
373{
374	struct iovec	 iov;
375	char		 buf[0x10000];
376	ssize_t		 n;
377
378	if (msg->msg_iov != NULL || msg->msg_iovlen != 0)
379		return (NULL);
380
381	iov.iov_base = buf;
382	iov.iov_len = sizeof(buf);
383	msg->msg_iov = &iov;
384	msg->msg_iovlen = 1;
385	n = recvmsg(s, msg, flags);
386	msg->msg_iov = NULL;
387	msg->msg_iovlen = 0;
388	if (n <= 0)
389		return (NULL);
390
391	return (radius_convert_packet(buf, (size_t) n));
392}
393
394int
395radius_sendmsg(int s, const RADIUS_PACKET * packet,
396    const struct msghdr * msg, int flags)
397{
398	struct msghdr	 msg0;
399	struct iovec	 iov;
400	ssize_t		 n;
401
402	if (msg->msg_iov != NULL || msg->msg_iovlen != 0)
403		return (-1);
404
405	iov.iov_base = packet->pdata;
406	iov.iov_len = radius_get_length(packet);
407	msg0 = *msg;
408	msg0.msg_iov = &iov;
409	msg0.msg_iovlen = 1;
410	n = sendmsg(s, &msg0, flags);
411	if (n != radius_get_length(packet))
412		return (-1);
413
414	return (0);
415}
416