1/*	$OpenBSD: handshake_table.c,v 1.18 2022/12/01 13:49:12 tb Exp $	*/
2/*
3 * Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#include <err.h>
19#include <stdint.h>
20#include <stdio.h>
21#include <stdlib.h>
22#include <unistd.h>
23
24#include "tls13_handshake.h"
25
/*
 * Number of rows in a handshake table: one per possible value of the
 * uint8_t flags bitmask, so any flag combination is a valid row index.
 */
#define MAX_FLAGS (UINT8_MAX + 1)
27
28/*
29 * From RFC 8446:
30 *
31 * Appendix A.  State Machine
32 *
33 *    This appendix provides a summary of the legal state transitions for
34 *    the client and server handshakes.  State names (in all capitals,
35 *    e.g., START) have no formal meaning but are provided for ease of
36 *    comprehension.  Actions which are taken only in certain circumstances
37 *    are indicated in [].  The notation "K_{send,recv} = foo" means "set
38 *    the send/recv key to the given key".
39 *
40 * A.1.  Client
41 *
42 *                               START <----+
43 *                Send ClientHello |        | Recv HelloRetryRequest
44 *           [K_send = early data] |        |
45 *                                 v        |
46 *            /                 WAIT_SH ----+
47 *            |                    | Recv ServerHello
48 *            |                    | K_recv = handshake
49 *        Can |                    V
50 *       send |                 WAIT_EE
51 *      early |                    | Recv EncryptedExtensions
52 *       data |           +--------+--------+
53 *            |     Using |                 | Using certificate
54 *            |       PSK |                 v
55 *            |           |            WAIT_CERT_CR
56 *            |           |        Recv |       | Recv CertificateRequest
57 *            |           | Certificate |       v
58 *            |           |             |    WAIT_CERT
59 *            |           |             |       | Recv Certificate
60 *            |           |             v       v
61 *            |           |              WAIT_CV
62 *            |           |                 | Recv CertificateVerify
63 *            |           +> WAIT_FINISHED <+
64 *            |                  | Recv Finished
65 *            \                  | [Send EndOfEarlyData]
66 *                               | K_send = handshake
67 *                               | [Send Certificate [+ CertificateVerify]]
68 *     Can send                  | Send Finished
69 *     app data   -->            | K_send = K_recv = application
70 *     after here                v
71 *                           CONNECTED
72 *
73 *    Note that with the transitions as shown above, clients may send
74 *    alerts that derive from post-ServerHello messages in the clear or
75 *    with the early data keys.  If clients need to send such alerts, they
76 *    SHOULD first rekey to the handshake keys if possible.
77 *
78 */
79
/*
 * One edge of the handshake state graph: the next message type plus the
 * flag bits that control whether build_table() may follow the edge.
 */
struct child {
	/* Next message type; an entry with mt == 0 ends a child list. */
	enum tls13_message_type	mt;
	/* Bit(s) OR-ed into the path's flags when this edge is taken. */
	uint8_t			flag;
	/* If nonzero, the edge is taken only if one of these bits is set. */
	uint8_t			forced;
	/* The edge is not taken if any of these bits is set. */
	uint8_t			illegal;
};
86
87static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
88	[CLIENT_HELLO] = {
89		{
90			.mt = SERVER_HELLO_RETRY_REQUEST,
91		},
92		{
93			.mt = SERVER_HELLO,
94			.flag = WITHOUT_HRR,
95		},
96	},
97	[SERVER_HELLO_RETRY_REQUEST] = {
98		{
99			.mt = CLIENT_HELLO_RETRY,
100		},
101	},
102	[CLIENT_HELLO_RETRY] = {
103		{
104			.mt = SERVER_HELLO,
105		},
106	},
107	[SERVER_HELLO] = {
108		{
109			.mt = SERVER_ENCRYPTED_EXTENSIONS,
110		},
111	},
112	[SERVER_ENCRYPTED_EXTENSIONS] = {
113		{
114			.mt = SERVER_CERTIFICATE_REQUEST,
115		},
116		{	.mt = SERVER_CERTIFICATE,
117			.flag = WITHOUT_CR,
118		},
119		{
120			.mt = SERVER_FINISHED,
121			.flag = WITH_PSK,
122		},
123	},
124	[SERVER_CERTIFICATE_REQUEST] = {
125		{
126			.mt = SERVER_CERTIFICATE,
127		},
128	},
129	[SERVER_CERTIFICATE] = {
130		{
131			.mt = SERVER_CERTIFICATE_VERIFY,
132		},
133	},
134	[SERVER_CERTIFICATE_VERIFY] = {
135		{
136			.mt = SERVER_FINISHED,
137		},
138	},
139	[SERVER_FINISHED] = {
140		{
141			.mt = CLIENT_FINISHED,
142			.forced = WITHOUT_CR | WITH_PSK,
143		},
144		{
145			.mt = CLIENT_CERTIFICATE,
146			.illegal = WITHOUT_CR | WITH_PSK,
147		},
148	},
149	[CLIENT_CERTIFICATE] = {
150		{
151			.mt = CLIENT_FINISHED,
152		},
153		{
154			.mt = CLIENT_CERTIFICATE_VERIFY,
155			.flag = WITH_CCV,
156		},
157	},
158	[CLIENT_CERTIFICATE_VERIFY] = {
159		{
160			.mt = CLIENT_FINISHED,
161		},
162	},
163	[CLIENT_FINISHED] = {
164		{
165			.mt = APPLICATION_DATA,
166		},
167	},
168	[APPLICATION_DATA] = {
169		{
170			.mt = 0,
171		},
172	},
173};
174
175const size_t	 stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
176
177void		 build_table(enum tls13_message_type
178		     table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
179		     struct child current, struct child end,
180		     struct child path[], uint8_t flags, unsigned int depth);
181size_t		 count_handshakes(void);
182void		 edge(enum tls13_message_type start,
183		     enum tls13_message_type end, uint8_t flag);
184const char	*flag2str(uint8_t flag);
185void		 flag_label(uint8_t flag);
186void		 forced_edges(enum tls13_message_type start,
187		     enum tls13_message_type end, uint8_t forced);
188int		 generate_graphics(void);
189void		 fprint_entry(FILE *stream,
190		     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
191		     uint8_t flags);
192void		 fprint_flags(FILE *stream, uint8_t flags);
193const char	*mt2str(enum tls13_message_type mt);
194void		 usage(void);
195int		 verify_table(enum tls13_message_type
196		     table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print);
197
198const char *
199flag2str(uint8_t flag)
200{
201	const char *ret;
202
203	if (flag & (flag - 1))
204		errx(1, "more than one bit is set");
205
206	switch (flag) {
207	case INITIAL:
208		ret = "INITIAL";
209		break;
210	case NEGOTIATED:
211		ret = "NEGOTIATED";
212		break;
213	case WITHOUT_CR:
214		ret = "WITHOUT_CR";
215		break;
216	case WITHOUT_HRR:
217		ret = "WITHOUT_HRR";
218		break;
219	case WITH_PSK:
220		ret = "WITH_PSK";
221		break;
222	case WITH_CCV:
223		ret = "WITH_CCV";
224		break;
225	case WITH_0RTT:
226		ret = "WITH_0RTT";
227		break;
228	default:
229		ret = "UNKNOWN";
230	}
231
232	return ret;
233}
234
235const char *
236mt2str(enum tls13_message_type mt)
237{
238	const char *ret;
239
240	switch (mt) {
241	case INVALID:
242		ret = "INVALID";
243		break;
244	case CLIENT_HELLO:
245		ret = "CLIENT_HELLO";
246		break;
247	case CLIENT_HELLO_RETRY:
248		ret = "CLIENT_HELLO_RETRY";
249		break;
250	case CLIENT_END_OF_EARLY_DATA:
251		ret = "CLIENT_END_OF_EARLY_DATA";
252		break;
253	case CLIENT_CERTIFICATE:
254		ret = "CLIENT_CERTIFICATE";
255		break;
256	case CLIENT_CERTIFICATE_VERIFY:
257		ret = "CLIENT_CERTIFICATE_VERIFY";
258		break;
259	case CLIENT_FINISHED:
260		ret = "CLIENT_FINISHED";
261		break;
262	case SERVER_HELLO:
263		ret = "SERVER_HELLO";
264		break;
265	case SERVER_HELLO_RETRY_REQUEST:
266		ret = "SERVER_HELLO_RETRY_REQUEST";
267		break;
268	case SERVER_ENCRYPTED_EXTENSIONS:
269		ret = "SERVER_ENCRYPTED_EXTENSIONS";
270		break;
271	case SERVER_CERTIFICATE:
272		ret = "SERVER_CERTIFICATE";
273		break;
274	case SERVER_CERTIFICATE_VERIFY:
275		ret = "SERVER_CERTIFICATE_VERIFY";
276		break;
277	case SERVER_CERTIFICATE_REQUEST:
278		ret = "SERVER_CERTIFICATE_REQUEST";
279		break;
280	case SERVER_FINISHED:
281		ret = "SERVER_FINISHED";
282		break;
283	case APPLICATION_DATA:
284		ret = "APPLICATION_DATA";
285		break;
286	case TLS13_NUM_MESSAGE_TYPES:
287		ret = "TLS13_NUM_MESSAGE_TYPES";
288		break;
289	default:
290		ret = "UNKNOWN";
291		break;
292	}
293
294	return ret;
295}
296
/*
 * Write flags to stream as its set bits' names joined by " | ", lowest
 * bit first.  A zero value is written as the single name flag2str(0).
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char	*sep = "";
	unsigned int	 bit;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(0));
		return;
	}

	for (bit = 1; bit <= UINT8_MAX; bit <<= 1) {
		if ((flags & bit) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(flags & bit));
		sep = " | ";
	}
}
317
318void
319fprint_entry(FILE *stream,
320    enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
321{
322	int i;
323
324	fprintf(stream, "\t[");
325	fprint_flags(stream, flags);
326	fprintf(stream, "] = {\n");
327
328	for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
329		if (path[i] == 0)
330			break;
331		fprintf(stream, "\t\t%s,\n", mt2str(path[i]));
332	}
333	fprintf(stream, "\t},\n");
334}
335
336void
337edge(enum tls13_message_type start, enum tls13_message_type end,
338    uint8_t flag)
339{
340	printf("\t%s -> %s", mt2str(start), mt2str(end));
341	flag_label(flag);
342	printf(";\n");
343}
344
/* Print a Graphviz edge label for flag to stdout; nothing if flag is 0. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
351
352void
353forced_edges(enum tls13_message_type start, enum tls13_message_type end,
354    uint8_t forced)
355{
356	uint8_t	forced_flag, i;
357
358	if (forced == 0)
359		return;
360
361	for (i = 0; i < 8; i++) {
362		forced_flag = forced & (1U << i);
363		if (forced_flag)
364			edge(start, end, forced_flag);
365	}
366}
367
368int
369generate_graphics(void)
370{
371	enum tls13_message_type	start, end;
372	unsigned int		child;
373	uint8_t			flag;
374	uint8_t			forced;
375
376	printf("digraph G {\n");
377	printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
378	printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
379
380	for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
381		for (child = 0; stateinfo[start][child].mt != 0; child++) {
382			end = stateinfo[start][child].mt;
383			flag = stateinfo[start][child].flag;
384			forced = stateinfo[start][child].forced;
385
386			if (forced == 0)
387				edge(start, end, flag);
388			else
389				forced_edges(start, end, forced);
390		}
391	}
392
393	printf("}\n");
394	return 0;
395}
396
/*
 * Reference handshake table and its number of rows, defined outside this
 * file (presumably libssl's tls13_handshake.c -- confirm).  verify_table()
 * indexes it by flags, and count_handshakes() counts rows whose first
 * entry is not INVALID.
 *
 * NOTE(review): these declarations must have the same qualifiers as the
 * definitions; if the library defines them const, add const here too.
 */
extern enum tls13_message_type	handshakes[][TLS13_NUM_MESSAGE_TYPES];
extern size_t			handshake_count;
399
400size_t
401count_handshakes(void)
402{
403	size_t	ret = 0, i;
404
405	for (i = 0; i < handshake_count; i++) {
406		if (handshakes[i][0] != INVALID)
407			ret++;
408	}
409
410	return ret;
411}
412
413void
414build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
415    struct child current, struct child end, struct child path[], uint8_t flags,
416    unsigned int depth)
417{
418	unsigned int i;
419
420	if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
421		errx(1, "recursed too deeply");
422
423	/* Record current node. */
424	path[depth++] = current;
425	flags |= current.flag;
426
427	/* If we haven't reached the end, recurse over the children. */
428	if (current.mt != end.mt) {
429		for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
430			struct child child = stateinfo[current.mt][i];
431			int forced = stateinfo[current.mt][i].forced;
432			int illegal = stateinfo[current.mt][i].illegal;
433
434			if ((forced == 0 || (forced & flags)) &&
435			    (illegal == 0 || !(illegal & flags)))
436				build_table(table, child, end, path, flags,
437				    depth);
438		}
439		return;
440	}
441
442	if (flags == 0)
443		errx(1, "path does not set flags");
444
445	if (table[flags][0] != 0)
446		errx(1, "path traversed twice");
447
448	for (i = 0; i < depth; i++)
449		table[flags][i] = path[i].mt;
450}
451
452int
453verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
454    int print)
455{
456	int	success = 1, i;
457	size_t	num_valid, num_found = 0;
458	uint8_t	flags = 0;
459
460	do {
461		if (table[flags][0] == 0)
462			continue;
463
464		num_found++;
465
466		for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
467			if (table[flags][i] != handshakes[flags][i]) {
468				fprintf(stderr,
469				    "incorrect entry %d of handshake ", i);
470				fprint_flags(stderr, flags);
471				fprintf(stderr, "\n");
472				success = 0;
473			}
474		}
475
476		if (print)
477			fprint_entry(stdout, table[flags], flags);
478	} while(++flags != 0);
479
480	num_valid = count_handshakes();
481	if (num_valid != num_found) {
482		fprintf(stderr,
483		    "incorrect number of handshakes: want %zu, got %zu.\n",
484		    num_valid, num_found);
485		success = 0;
486	}
487
488	return success;
489}
490
/* Print the command-line synopsis to stderr and exit with status 1. */
void
usage(void)
{
	fputs("usage: handshake_table [-C | -g]\n", stderr);
	exit(1);
}
497
498int
499main(int argc, char *argv[])
500{
501	static enum tls13_message_type
502	    hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = {
503		[INITIAL] = {
504			CLIENT_HELLO,
505			SERVER_HELLO_RETRY_REQUEST,
506			CLIENT_HELLO_RETRY,
507			SERVER_HELLO,
508		},
509	};
510	struct child	start = {
511		.mt = CLIENT_HELLO,
512	};
513	struct child	end = {
514		.mt = APPLICATION_DATA,
515	};
516	struct child	path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
517	uint8_t		flags = NEGOTIATED;
518	unsigned int	depth = 0;
519	int		ch, graphviz = 0, print = 0;
520
521	while ((ch = getopt(argc, argv, "Cg")) != -1) {
522		switch (ch) {
523		case 'C':
524			print = 1;
525			break;
526		case 'g':
527			graphviz = 1;
528			break;
529		default:
530			usage();
531		}
532	}
533	argc -= optind;
534	argv += optind;
535
536	if (argc != 0)
537		usage();
538
539	if (graphviz && print)
540		usage();
541
542	if (graphviz)
543		return generate_graphics();
544
545	build_table(hs_table, start, end, path, flags, depth);
546	if (!verify_table(hs_table, print))
547		return 1;
548
549	return 0;
550}
551