/* $OpenBSD: handshake_table.c,v 1.18 2022/12/01 13:49:12 tb Exp $ */ /* * Copyright (c) 2019 Theo Buehler * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ #include #include #include #include #include #include "tls13_handshake.h" #define MAX_FLAGS (UINT8_MAX + 1) /* * From RFC 8446: * * Appendix A. State Machine * * This appendix provides a summary of the legal state transitions for * the client and server handshakes. State names (in all capitals, * e.g., START) have no formal meaning but are provided for ease of * comprehension. Actions which are taken only in certain circumstances * are indicated in []. The notation "K_{send,recv} = foo" means "set * the send/recv key to the given key". * * A.1. Client * * START <----+ * Send ClientHello | | Recv HelloRetryRequest * [K_send = early data] | | * v | * / WAIT_SH ----+ * | | Recv ServerHello * | | K_recv = handshake * Can | V * send | WAIT_EE * early | | Recv EncryptedExtensions * data | +--------+--------+ * | Using | | Using certificate * | PSK | v * | | WAIT_CERT_CR * | | Recv | | Recv CertificateRequest * | | Certificate | v * | | | WAIT_CERT * | | | | Recv Certificate * | | v v * | | WAIT_CV * | | | Recv CertificateVerify * | +> WAIT_FINISHED <+ * | | Recv Finished * \ | [Send EndOfEarlyData] * | K_send = handshake * | [Send Certificate [+ CertificateVerify]] * Can send | Send Finished * app data --> | K_send = K_recv = application * after here v * CONNECTED * * Note that with the transitions as shown above, clients may send * alerts that derive from post-ServerHello messages in the clear or * with the early data keys. If clients need to send such alerts, they * SHOULD first rekey to the handshake keys if possible. * */ struct child { enum tls13_message_type mt; uint8_t flag; uint8_t forced; uint8_t illegal; }; static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = { [CLIENT_HELLO] = { { .mt = SERVER_HELLO_RETRY_REQUEST, }, { .mt = SERVER_HELLO, .flag = WITHOUT_HRR, }, }, [SERVER_HELLO_RETRY_REQUEST] = { { .mt = CLIENT_HELLO_RETRY, }, }, [CLIENT_HELLO_RETRY] = { { .mt = SERVER_HELLO, }, }, [SERVER_HELLO] = { { .mt = SERVER_ENCRYPTED_EXTENSIONS, }, }, [SERVER_ENCRYPTED_EXTENSIONS] = { { .mt = SERVER_CERTIFICATE_REQUEST, }, { .mt = SERVER_CERTIFICATE, .flag = WITHOUT_CR, }, { .mt = SERVER_FINISHED, .flag = WITH_PSK, }, }, [SERVER_CERTIFICATE_REQUEST] = { { .mt = SERVER_CERTIFICATE, }, }, [SERVER_CERTIFICATE] = { { .mt = SERVER_CERTIFICATE_VERIFY, }, }, [SERVER_CERTIFICATE_VERIFY] = { { .mt = SERVER_FINISHED, }, }, [SERVER_FINISHED] = { { .mt = CLIENT_FINISHED, .forced = WITHOUT_CR | WITH_PSK, }, { .mt = CLIENT_CERTIFICATE, .illegal = WITHOUT_CR | WITH_PSK, }, }, [CLIENT_CERTIFICATE] = { { .mt = CLIENT_FINISHED, }, { .mt = CLIENT_CERTIFICATE_VERIFY, .flag = WITH_CCV, }, }, [CLIENT_CERTIFICATE_VERIFY] = { { .mt = CLIENT_FINISHED, }, }, [CLIENT_FINISHED] = { { .mt = APPLICATION_DATA, }, }, [APPLICATION_DATA] = { { .mt = 0, }, }, }; const size_t stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]); void build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], struct child current, struct child end, struct child path[], uint8_t flags, unsigned int depth); size_t count_handshakes(void); void edge(enum tls13_message_type start, enum tls13_message_type end, uint8_t flag); const char *flag2str(uint8_t flag); void flag_label(uint8_t flag); void forced_edges(enum tls13_message_type start, enum tls13_message_type end, uint8_t forced); int generate_graphics(void); void fprint_entry(FILE *stream, enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags); void fprint_flags(FILE *stream, uint8_t flags); const char *mt2str(enum tls13_message_type mt); void usage(void); int verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print); const char * flag2str(uint8_t flag) { const char *ret; if (flag & (flag - 1)) errx(1, "more than one bit is set"); switch (flag) { case INITIAL: ret = "INITIAL"; break; case NEGOTIATED: ret = "NEGOTIATED"; break; case WITHOUT_CR: ret = "WITHOUT_CR"; break; case WITHOUT_HRR: ret = "WITHOUT_HRR"; break; case WITH_PSK: ret = "WITH_PSK"; break; case WITH_CCV: ret = "WITH_CCV"; break; case WITH_0RTT: ret = "WITH_0RTT"; break; default: ret = "UNKNOWN"; } return ret; } const char * mt2str(enum tls13_message_type mt) { const char *ret; switch (mt) { case INVALID: ret = "INVALID"; break; case CLIENT_HELLO: ret = "CLIENT_HELLO"; break; case CLIENT_HELLO_RETRY: ret = "CLIENT_HELLO_RETRY"; break; case CLIENT_END_OF_EARLY_DATA: ret = "CLIENT_END_OF_EARLY_DATA"; break; case CLIENT_CERTIFICATE: ret = "CLIENT_CERTIFICATE"; break; case CLIENT_CERTIFICATE_VERIFY: ret = "CLIENT_CERTIFICATE_VERIFY"; break; case CLIENT_FINISHED: ret = "CLIENT_FINISHED"; break; case SERVER_HELLO: ret = "SERVER_HELLO"; break; case SERVER_HELLO_RETRY_REQUEST: ret = "SERVER_HELLO_RETRY_REQUEST"; break; case SERVER_ENCRYPTED_EXTENSIONS: ret = "SERVER_ENCRYPTED_EXTENSIONS"; break; case SERVER_CERTIFICATE: ret = "SERVER_CERTIFICATE"; break; case SERVER_CERTIFICATE_VERIFY: ret = "SERVER_CERTIFICATE_VERIFY"; break; case SERVER_CERTIFICATE_REQUEST: ret = "SERVER_CERTIFICATE_REQUEST"; break; case SERVER_FINISHED: ret = "SERVER_FINISHED"; break; case APPLICATION_DATA: ret = "APPLICATION_DATA"; break; case TLS13_NUM_MESSAGE_TYPES: ret = "TLS13_NUM_MESSAGE_TYPES"; break; default: ret = "UNKNOWN"; break; } return ret; } void fprint_flags(FILE *stream, uint8_t flags) { int first = 1, i; if (flags == 0) { fprintf(stream, "%s", flag2str(flags)); return; } for (i = 0; i < 8; i++) { uint8_t set = flags & (1U << i); if (set) { fprintf(stream, "%s%s", first ? "" : " | ", flag2str(set)); first = 0; } } } void fprint_entry(FILE *stream, enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags) { int i; fprintf(stream, "\t["); fprint_flags(stream, flags); fprintf(stream, "] = {\n"); for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) { if (path[i] == 0) break; fprintf(stream, "\t\t%s,\n", mt2str(path[i])); } fprintf(stream, "\t},\n"); } void edge(enum tls13_message_type start, enum tls13_message_type end, uint8_t flag) { printf("\t%s -> %s", mt2str(start), mt2str(end)); flag_label(flag); printf(";\n"); } void flag_label(uint8_t flag) { if (flag) printf(" [label=\"%s\"]", flag2str(flag)); } void forced_edges(enum tls13_message_type start, enum tls13_message_type end, uint8_t forced) { uint8_t forced_flag, i; if (forced == 0) return; for (i = 0; i < 8; i++) { forced_flag = forced & (1U << i); if (forced_flag) edge(start, end, forced_flag); } } int generate_graphics(void) { enum tls13_message_type start, end; unsigned int child; uint8_t flag; uint8_t forced; printf("digraph G {\n"); printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO)); printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA)); for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) { for (child = 0; stateinfo[start][child].mt != 0; child++) { end = stateinfo[start][child].mt; flag = stateinfo[start][child].flag; forced = stateinfo[start][child].forced; if (forced == 0) edge(start, end, flag); else forced_edges(start, end, forced); } } printf("}\n"); return 0; } extern enum tls13_message_type handshakes[][TLS13_NUM_MESSAGE_TYPES]; extern size_t handshake_count; size_t count_handshakes(void) { size_t ret = 0, i; for (i = 0; i < handshake_count; i++) { if (handshakes[i][0] != INVALID) ret++; } return ret; } void build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], struct child current, struct child end, struct child path[], uint8_t flags, unsigned int depth) { unsigned int i; if (depth >= TLS13_NUM_MESSAGE_TYPES - 1) errx(1, "recursed too deeply"); /* Record current node. */ path[depth++] = current; flags |= current.flag; /* If we haven't reached the end, recurse over the children. */ if (current.mt != end.mt) { for (i = 0; stateinfo[current.mt][i].mt != 0; i++) { struct child child = stateinfo[current.mt][i]; int forced = stateinfo[current.mt][i].forced; int illegal = stateinfo[current.mt][i].illegal; if ((forced == 0 || (forced & flags)) && (illegal == 0 || !(illegal & flags))) build_table(table, child, end, path, flags, depth); } return; } if (flags == 0) errx(1, "path does not set flags"); if (table[flags][0] != 0) errx(1, "path traversed twice"); for (i = 0; i < depth; i++) table[flags][i] = path[i].mt; } int verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print) { int success = 1, i; size_t num_valid, num_found = 0; uint8_t flags = 0; do { if (table[flags][0] == 0) continue; num_found++; for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) { if (table[flags][i] != handshakes[flags][i]) { fprintf(stderr, "incorrect entry %d of handshake ", i); fprint_flags(stderr, flags); fprintf(stderr, "\n"); success = 0; } } if (print) fprint_entry(stdout, table[flags], flags); } while(++flags != 0); num_valid = count_handshakes(); if (num_valid != num_found) { fprintf(stderr, "incorrect number of handshakes: want %zu, got %zu.\n", num_valid, num_found); success = 0; } return success; } void usage(void) { fprintf(stderr, "usage: handshake_table [-C | -g]\n"); exit(1); } int main(int argc, char *argv[]) { static enum tls13_message_type hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = { [INITIAL] = { CLIENT_HELLO, SERVER_HELLO_RETRY_REQUEST, CLIENT_HELLO_RETRY, SERVER_HELLO, }, }; struct child start = { .mt = CLIENT_HELLO, }; struct child end = { .mt = APPLICATION_DATA, }; struct child path[TLS13_NUM_MESSAGE_TYPES] = {{0}}; uint8_t flags = NEGOTIATED; unsigned int depth = 0; int ch, graphviz = 0, print = 0; while ((ch = getopt(argc, argv, "Cg")) != -1) { switch (ch) { case 'C': print = 1; break; case 'g': graphviz = 1; break; default: usage(); } } argc -= optind; argv += optind; if (argc != 0) usage(); if (graphviz && print) usage(); if (graphviz) return generate_graphics(); build_table(hs_table, start, end, path, flags, depth); if (!verify_table(hs_table, print)) return 1; return 0; }