ssl_init.c revision 316722
1/*
2 * ssl_init.c	Common OpenSSL initialization code for the various
3 *		programs which use it.
4 *
5 * Moved from ntpd/ntp_crypto.c crypto_setup()
6 */
7#ifdef HAVE_CONFIG_H
8#include <config.h>
9#endif
10#include <ctype.h>
11#include <ntp.h>
12#include <ntp_debug.h>
13#include <lib_strbuf.h>
14
15#ifdef OPENSSL
16#include "openssl/crypto.h"
17#include "openssl/err.h"
18#include "openssl/evp.h"
19#include "openssl/opensslv.h"
20#include "libssl_compat.h"
21
/*
 * Nonzero once ssl_init() has run (both variants below set it TRUE).
 * With pre-1.1.0 OpenSSL, atexit_ssl_cleanup() clears it again.
 * Not static -- presumably read via the INIT_SSL() macro in other
 * files (TODO confirm).
 */
int ssl_init_done;
23
24#if OPENSSL_VERSION_NUMBER < 0x10100000L
25
26static void
27atexit_ssl_cleanup(void)
28{
29	if (!ssl_init_done)
30		return;
31
32	ssl_init_done = FALSE;
33	EVP_cleanup();
34	ERR_free_strings();
35}
36
37void
38ssl_init(void)
39{
40	init_lib();
41
42	if ( ! ssl_init_done) {
43	    ERR_load_crypto_strings();
44	    OpenSSL_add_all_algorithms();
45	    atexit(&atexit_ssl_cleanup);
46	    ssl_init_done = TRUE;
47	}
48}
49
50#else /* OPENSSL_VERSION_NUMBER >= 0x10100000L */
51
52void
53ssl_init(void)
54{
55	init_lib();
56	ssl_init_done = TRUE;
57}
58
59#endif /* OPENSSL_VERSION_NUMBER */
60
61
62void
63ssl_check_version(void)
64{
65	u_long	v;
66
67	v = OpenSSL_version_num();
68	if ((v ^ OPENSSL_VERSION_NUMBER) & ~0xff0L) {
69		msyslog(LOG_WARNING,
70		    "OpenSSL version mismatch. Built against %lx, you have %lx",
71		    (u_long)OPENSSL_VERSION_NUMBER, v);
72		fprintf(stderr,
73		    "OpenSSL version mismatch. Built against %lx, you have %lx\n",
74		    (u_long)OPENSSL_VERSION_NUMBER, v);
75	}
76
77	INIT_SSL();
78}
79
80#endif	/* OPENSSL */
81
82
83/*
84 * keytype_from_text	returns OpenSSL NID for digest by name, and
85 *			optionally the associated digest length.
86 *
87 * Used by ntpd authreadkeys(), ntpq and ntpdc keytype()
88 */
89int
90keytype_from_text(
91	const char *text,
92	size_t *pdigest_len
93	)
94{
95	int		key_type;
96	u_int		digest_len;
97#ifdef OPENSSL
98	const u_long	max_digest_len = MAX_MAC_LEN - sizeof(keyid_t);
99	u_char		digest[EVP_MAX_MD_SIZE];
100	char *		upcased;
101	char *		pch;
102
103	/*
104	 * OpenSSL digest short names are capitalized, so uppercase the
105	 * digest name before passing to OBJ_sn2nid().  If it is not
106	 * recognized but begins with 'M' use NID_md5 to be consistent
107	 * with past behavior.
108	 */
109	INIT_SSL();
110	LIB_GETBUF(upcased);
111	strlcpy(upcased, text, LIB_BUFLENGTH);
112	for (pch = upcased; '\0' != *pch; pch++)
113		*pch = (char)toupper((unsigned char)*pch);
114	key_type = OBJ_sn2nid(upcased);
115#else
116	key_type = 0;
117#endif
118
119	if (!key_type && 'm' == tolower((unsigned char)text[0]))
120		key_type = NID_md5;
121
122	if (!key_type)
123		return 0;
124
125	if (NULL != pdigest_len) {
126#ifdef OPENSSL
127		EVP_MD_CTX	*ctx;
128
129		ctx = EVP_MD_CTX_new();
130		EVP_DigestInit(ctx, EVP_get_digestbynid(key_type));
131		EVP_DigestFinal(ctx, digest, &digest_len);
132		EVP_MD_CTX_free(ctx);
133		if (digest_len > max_digest_len) {
134			fprintf(stderr,
135				"key type %s %u octet digests are too big, max %lu\n",
136				keytype_name(key_type), digest_len,
137				max_digest_len);
138			msyslog(LOG_ERR,
139				"key type %s %u octet digests are too big, max %lu",
140				keytype_name(key_type), digest_len,
141				max_digest_len);
142			return 0;
143		}
144#else
145		digest_len = 16;
146#endif
147		*pdigest_len = digest_len;
148	}
149
150	return key_type;
151}
152
153
154/*
155 * keytype_name		returns OpenSSL short name for digest by NID.
156 *
157 * Used by ntpq and ntpdc keytype()
158 */
159const char *
160keytype_name(
161	int nid
162	)
163{
164	static const char unknown_type[] = "(unknown key type)";
165	const char *name;
166
167#ifdef OPENSSL
168	INIT_SSL();
169	name = OBJ_nid2sn(nid);
170	if (NULL == name)
171		name = unknown_type;
172#else	/* !OPENSSL follows */
173	if (NID_md5 == nid)
174		name = "MD5";
175	else
176		name = unknown_type;
177#endif
178	return name;
179}
180
181
/*
 * Use getpassphrase() if configure.ac detected it: on Sun systems that
 * provide it, getpass() truncates passwords to 8 characters.
 */
186#ifdef HAVE_GETPASSPHRASE
187# define	getpass(str)	getpassphrase(str)
188#endif
189
/*
 * getpass_keytype() -- prompt for the password of a key of the given
 *			type, e.g. "MD5 Password: ", and return the
 *			result of getpass() (getpassphrase() where
 *			available).  Shared between ntpq and ntpdc;
 *			only vaguely related to the rest of ssl_init.c.
 */
char *
getpass_keytype(
	int	keytype
	)
{
	/* at most 64 chars of type name, then " Password: " (11), NUL */
	char	prompt[64 + 11 + 1];

	snprintf(prompt, sizeof prompt, "%.64s Password: ",
		 keytype_name(keytype));
	return getpass(prompt);
}
206