1/*
2 * Copyright (c) 2006 - 2010 Kungliga Tekniska Högskolan
3 * (Royal Institute of Technology, Stockholm, Sweden).
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * 3. Neither the name of KTH nor the names of its contributors may be
18 *    used to endorse or promote products derived from this software without
19 *    specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY KTH AND ITS CONTRIBUTORS ``AS IS'' AND ANY
22 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
24 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL KTH OR ITS CONTRIBUTORS BE
25 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
28 * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
29 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
30 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
31 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34#include "config.h"
35
36#include <stdio.h>
37#include <gssapi.h>
38#include <gssapi_scram.h>
39#include <gssapi_spi.h>
40#include <err.h>
41#include <roken.h>
42#include <hex.h>
43#include <getarg.h>
44#include "test_common.h"
45
/*
 * Targets of the --user and --password options in the args table;
 * both stay NULL (static zero-initialisation) until getarg() sets them.
 */
static char *user_string;
static char *password_string;
48
49#ifdef ENABLE_SCRAM
50
51#include <heimscram.h>
52
53/*
54 *
55 */
56
57static unsigned int giterations = 1000;
58static heim_scram_data gsalt = {
59    .data = rk_UNCONST("salt"),
60    .length = 4
61};
62
63static int
64param(void *ctx,
65      const heim_scram_data *user,
66      heim_scram_data *salt,
67      unsigned int *iteration,
68      heim_scram_data *servernonce)
69{
70    if (user->length != strlen(user_string) && memcmp(user->data, user_string, user->length) != 0)
71	return ENOENT;
72
73    *iteration = giterations;
74
75    salt->data = malloc(gsalt.length);
76    memcpy(salt->data, gsalt.data, gsalt.length);
77    salt->length = gsalt.length;
78
79    servernonce->data = NULL;
80    servernonce->length = 0;
81
82    return 0;
83}
84
85static int
86calculate(void *ctx,
87	  heim_scram_method method,
88	  const heim_scram_data *user,
89	  const heim_scram_data *c1,
90	  const heim_scram_data *s1,
91	  const heim_scram_data *c2noproof,
92	  const heim_scram_data *proof,
93	  heim_scram_data *server,
94	  heim_scram_data *sessionKey)
95{
96    heim_scram_data client_key, client_key2, stored_key, server_key, clientSig;
97    int ret;
98
99    memset(&client_key2, 0, sizeof(client_key2));
100
101    ret = heim_scram_stored_key(method,
102				password_string, giterations, &gsalt,
103				&client_key, &stored_key, &server_key);
104    if (ret)
105	return ret;
106
107    ret = heim_scram_generate(method, &stored_key, &server_key,
108			      c1, s1, c2noproof, &clientSig, server);
109    heim_scram_data_free(&server_key);
110    if (ret)
111	goto out;
112
113    ret = heim_scram_validate_client_signature(method,
114					       &stored_key,
115					       &clientSig,
116					       proof,
117					       &client_key2);
118    if (ret)
119	goto out;
120
121
122    /* extra check since we know the client key */
123    if (client_key2.length != client_key.length ||
124	memcmp(client_key.data, client_key2.data, client_key.length) != 0) {
125	ret = EINVAL;
126	goto out;
127    }
128
129    ret = heim_scram_session_key(method,
130				 &stored_key,
131				 &client_key,
132				 c1, s1, c2noproof, sessionKey);
133    if (ret)
134	goto out;
135
136 out:
137    heim_scram_data_free(&stored_key);
138    heim_scram_data_free(&client_key);
139    heim_scram_data_free(&client_key2);
140
141    return ret;
142}
143
144static struct heim_scram_server server_proc = {
145    .version = SCRAM_SERVER_VERSION_1,
146    .param = param,
147    .calculate = calculate
148};
149
150static gss_cred_id_t client_cred = GSS_C_NO_CREDENTIAL;
151
152static void
153ac_complete(void *ctx, OM_uint32 major, gss_status_id_t status,
154	    gss_cred_id_t cred, gss_OID_set oids, OM_uint32 time_rec)
155{
156    OM_uint32 junk;
157
158    if (major) {
159	fprintf(stderr, "error: %d", (int)major);
160	gss_release_cred(&junk, &cred);
161	goto out;
162    }
163
164    client_cred = cred;
165
166 out:
167    gss_release_oid_set(&junk, &oids);
168}
169
170
171static int
172test_scram(const char *test_name, const char *user, const char *password)
173{
174    gss_name_t cname, target = GSS_C_NO_NAME;
175    OM_uint32 maj_stat, min_stat;
176    gss_ctx_id_t ctx = GSS_C_NO_CONTEXT;
177    gss_buffer_desc cn, input, output, output2;
178    int ret;
179    heim_scram *scram = NULL;
180    heim_scram_data in, out;
181    gss_auth_identity_desc identity;
182
183
184    memset(&identity, 0, sizeof(identity));
185
186    identity.username = rk_UNCONST(user);
187    identity.realm = "";
188    identity.password = rk_UNCONST(password);
189
190    cn.value = rk_UNCONST(user);
191    cn.length = strlen(user);
192
193    maj_stat = gss_import_name(&min_stat, &cn, GSS_C_NT_USER_NAME, &cname);
194    if (maj_stat)
195	errx(1, "gss_import_name: %d", (int)maj_stat);
196
197    maj_stat = gss_acquire_cred_ex_f(NULL,
198				     cname,
199				     0,
200				     GSS_C_INDEFINITE,
201				     GSS_SCRAM_MECHANISM,
202				     GSS_C_INITIATE,
203				     &identity,
204				     NULL,
205				     ac_complete);
206    if (maj_stat)
207	errx(1, "gss_acquire_cred_ex_f: %d", (int)maj_stat);
208
209    if (client_cred == GSS_C_NO_CREDENTIAL)
210	errx(1, "gss_acquire_cred_ex_f");
211
212    cn.value = rk_UNCONST("host@localhost");
213    cn.length = strlen((char *)cn.value);
214
215    maj_stat = gss_import_name(&min_stat, &cn,
216			       GSS_C_NT_HOSTBASED_SERVICE, &target);
217    if (maj_stat)
218	errx(1, "gss_import_name: %d", (int)maj_stat);
219
220    maj_stat = gss_init_sec_context(&min_stat, client_cred, &ctx,
221				    target, GSS_SCRAM_MECHANISM,
222				    0, 0, NULL,
223				    GSS_C_NO_BUFFER, NULL,
224				    &output, NULL, NULL);
225    if (maj_stat != GSS_S_CONTINUE_NEEDED)
226	errx(1, "accept_sec_context %s %s", test_name,
227	      gssapi_err(maj_stat, min_stat, GSS_C_NO_OID));
228
229    if (output.length == 0)
230	errx(1, "output.length == 0");
231
232    maj_stat = gss_decapsulate_token(&output, GSS_SCRAM_MECHANISM, &output2);
233    if (maj_stat)
234	errx(1, "decapsulate token");
235
236    in.length = output2.length;
237    in.data = output2.value;
238
239    ret = heim_scram_server1(&in, NULL, HEIM_SCRAM_DIGEST_SHA1, &server_proc, NULL, &scram, &out);
240    if (ret)
241	errx(1, "heim_scram_server1");
242
243    gss_release_buffer(&min_stat, &output);
244
245    input.length = out.length;
246    input.value = out.data;
247
248    maj_stat = gss_init_sec_context(&min_stat, client_cred, &ctx,
249				    target, GSS_SCRAM_MECHANISM,
250				    0, 0, NULL,
251				    &input, NULL,
252				    &output, NULL, NULL);
253    if (maj_stat != GSS_S_CONTINUE_NEEDED) {
254	warnx("accept_sec_context v1 2 %s",
255	     gssapi_err(maj_stat, min_stat, GSS_C_NO_OID));
256	return 1;
257    }
258
259    in.length = output.length;
260    in.data = output.value;
261
262    ret = heim_scram_server2(&in, scram, &out);
263    if (ret)
264	errx(1, "heim_scram_server2");
265
266    gss_release_buffer(&min_stat, &output);
267
268    input.length = out.length;
269    input.value = out.data;
270
271    maj_stat = gss_init_sec_context(&min_stat, client_cred, &ctx,
272				    target, GSS_SCRAM_MECHANISM,
273				    0, 0, NULL,
274				    &input, NULL,
275				    &output, NULL, NULL);
276    if (maj_stat != GSS_S_COMPLETE) {
277	warnx("accept_sec_context v1 2 %s",
278	     gssapi_err(maj_stat, min_stat, GSS_C_NO_OID));
279	return 1;
280    }
281
282    heim_scram_free(scram);
283
284    //gss_destroy_cred(NULL, &client_cred);
285
286    printf("done: %s\n", test_name);
287
288    return 0;
289}
290
291#endif /* ENABLE_SCRAM */
292
293/*
294 *
295 */
296
297static int version_flag = 0;
298static int help_flag	= 0;
299
300static struct getargs args[] = {
301    {"user",	0,	arg_string,	&user_string, "user name", "user" },
302    {"password",0,	arg_string,	&password_string, "password", "password" },
303    {"version",	0,	arg_flag,	&version_flag, "print version", NULL },
304    {"help",	0,	arg_flag,	&help_flag,  NULL, NULL }
305};
306
307static void
308usage (int ret)
309{
310    arg_printusage (args, sizeof(args)/sizeof(*args),
311		    NULL, "");
312    exit (ret);
313}
314
315int
316main(int argc, char **argv)
317{
318    int ret = 0, optidx = 0;
319
320    setprogname(argv[0]);
321
322    if(getarg(args, sizeof(args) / sizeof(args[0]), argc, argv, &optidx))
323	usage(1);
324
325    if (help_flag)
326	usage(0);
327
328    if(version_flag){
329	print_version(NULL);
330	exit(0);
331    }
332
333    if (user_string == NULL)
334	errx(1, "no username");
335    if (password_string == NULL)
336	errx(1, "no password");
337
338#ifdef ENABLE_SCRAM
339    ret += test_scram("scram", user_string, password_string);
340#endif
341
342    return (ret != 0) ? 1 : 0;
343}
344