1#include <stdio.h>
2#include <string.h>
3
4#include <barrelfish/barrelfish.h>
5#include <barrelfish/nameservice_client.h>
6#include <if/mt_waitset_defs.h>
7#include <if/mt_waitset_defs.h>
8#include <barrelfish/deferred.h>
9#include <barrelfish/inthandler.h>
10#include <bench/bench.h>
11#include <sys/time.h>
12#include "../lib/barrelfish/include/threads_priv.h"
13#include <barrelfish/debug.h>
14#include <barrelfish/spawn_client.h>
15#include <barrelfish/event_mutex.h>
16
17const static char *service_name = "mt_waitset_service";
18coreid_t my_core_id, num_cores;
19struct thread *threads[1024];
20
21static int server_threads = 10;
22static int client_threads = 1;
23static int iteration_count = 1000;
24
25static int client_counter = 0;
26static int64_t server_calls[1024];
27static int64_t client_calls[1024][1024];
28
29#ifdef __x86_64__
30#define START_ON_ALL_CORES true
31#else
32#define START_ON_ALL_CORES false
33#endif
34
35static void show_stats(void)
36{
37    debug_printf("Stats: %zd %zd %zd %zd %zd %zd %zd %zd %zd %zd\n",
38        server_calls[0], server_calls[1], server_calls[2], server_calls[3],
39        server_calls[4], server_calls[5], server_calls[6], server_calls[7],
40        server_calls[8], server_calls[9]);
41}
42
43static void show_client_stats(void)
44{
45    int i, j, s;
46    char text[256];
47
48    for (i = 0; i < num_cores; i++) {
49        s = sprintf(text, "Core %d:", i);
50        for (j = 0; j < 16; j++)
51            s += sprintf(text + s, "\t%zd", client_calls[i][j]);
52        s += sprintf(text + s, "\n");
53        debug_printf("%s", text);
54    }
55}
56
57static int client_thread(void * arg)
58{
59    struct mt_waitset_binding *binding;
60    errval_t err;
61    binding = arg;
62    int i, j, k, l;
63    uint64_t payload[512];
64    uint64_t result[512];
65    size_t result_size;
66    uint64_t o1;
67    uint32_t o2;
68    uint32_t i1 = my_core_id << 8 | thread_self()->id;
69    uint64_t mmm = ((uint64_t)my_core_id << 56) | ((uint64_t)thread_self()->id << 48);
70
71    debug_printf("Start\n");
72
73    for (k = 0; k < iteration_count; k++) {
74        uint64_t i2 = (rdtsc() & 0xffffffff) | mmm | (((uint64_t)k & 0xffffL) << 32);
75
76        j = ((i2 >> 5) & 511) + 1;
77
78        i2 &= 0xfffffffffffff000;
79
80        for (i = 0; i < j; i++)
81            payload[i] = i2 + i;
82        err = binding->rpc_tx_vtbl.rpc_method(binding, i2, (uint8_t *)payload, 8 * j, i1, &o1, (uint8_t *)result, &result_size, &o2);
83
84        assert(err == SYS_ERR_OK);
85        l = 0;
86        for (i = 0; i < j; i++) {
87            if (result[i] == payload[i] + i)
88                l++;
89        }
90        if (!(i2 + 1 == o1) || result_size != (8 * j) || l != j) {
91            debug_printf("%d: wrong %016lx != %016lx  %d %zd    %d %d\n", k, i2 + 1, o1, 8 * j, result_size, j, l);
92            for (i = 0; i < j; i++)
93                debug_printf("\t%d: %016lx %016lx\n", i, payload[i], result[i]);
94        }
95        server_calls[o2]++;
96        if (err_is_fail(err)) {
97            DEBUG_ERR(err, "error sending message\n");
98        }
99    }
100
101    dispatcher_handle_t handle = disp_disable();
102
103    __sync_fetch_and_sub(&client_counter, 1);
104
105    debug_printf("Done, threads left:%d\n", client_counter);
106
107    if (client_counter == 0) {
108        disp_enable(handle);
109        // all threads have finished, we're done, inform the server
110        payload[0] = mmm;
111        err = binding->rpc_tx_vtbl.rpc_method(binding, mmm, (uint8_t *)payload, 8, 65536, &o1, (uint8_t *)result, &result_size, &o2);
112        show_stats();
113    } else
114        disp_enable(handle);
115    return 0;
116}
117
118static void bind_cb(void *st, errval_t err, struct mt_waitset_binding *b)
119{
120    int i = (long int)st;
121
122    mt_waitset_rpc_client_init(b);
123
124    client_counter = client_threads;
125    for (i = 1; i < client_threads; i++) {
126        threads[i] = thread_create(client_thread, b);
127        assert(threads[i]);
128    }
129
130    client_thread(b);
131
132    for (i = 1; i < client_threads; i++) {
133        int res;
134        err = thread_join(threads[i], &res);
135        assert(err_is_ok(err));
136    }
137
138    debug_printf("client done.\n");
139}
140
141static void start_client(void)
142{
143    char name[64];
144    errval_t err;
145    iref_t iref;
146
147    debug_printf("Start client\n");
148    sprintf(name, "%s%d", service_name, 0);
149    err = nameservice_blocking_lookup(service_name, &iref);
150    if (err_is_fail(err)) {
151        USER_PANIC_ERR(err, "nameservice_blocking_lookup failed");
152    }
153    err = mt_waitset_bind(iref, bind_cb,  (void *)0, get_default_waitset(), IDC_BIND_FLAGS_DEFAULT);
154    if (err_is_fail(err)) {
155        USER_PANIC_ERR(err, "bind failed");
156    }
157}
158
159
160// server
161
162static void export_cb(void *st, errval_t err, iref_t iref)
163{
164    if (err_is_fail(err)) {
165        USER_PANIC_ERR(err, "export failed");
166    }
167    err = nameservice_register(service_name, iref);
168    if (err_is_fail(err)) {
169            USER_PANIC_ERR(err, "nameservice_register failed");
170    }
171}
172
173static errval_t server_rpc_method_call(struct mt_waitset_binding *b, uint64_t i1, const uint8_t *s, size_t ss, uint32_t i2, uint64_t *o1, uint8_t *r, size_t *rs, uint32_t *o2)
174{
175    int i, j, k, me;
176    static int count = 0;
177    static uint64_t calls = 0;
178    uint64_t *response = (uint64_t *)r;
179
180    for (i = 0;; i++) {
181        if (thread_self() == threads[i]) {
182            server_calls[i]++;
183            me = i;
184            break;
185        }
186    }
187
188    if (i2 == 65536) {
189        __sync_fetch_and_add(&count, 1);    // client has finished
190    } else
191        client_calls[i2 >> 8][i2 & 255]++;
192
193    j = ss / 8;
194    k = 0;
195    for (i = 0; i < j; i++) {
196        response[i] = ((uint64_t *)s)[i];
197        if (response[i] == i1 + i)
198            k++;
199        response[i] += i;
200    }
201    if (k != j && i2 != 65536)
202        debug_printf("%s: binding:%p %08x %08x  %d %d   %016lx:%d\n", __func__, b, i2, b->incoming_token, k, j, response[0], me);
203#if START_ON_ALL_CORES
204    if (count == num_cores) {
205#else
206    if (count == num_cores - 1) {
207#endif
208        bool failed = false;
209
210        debug_printf("Final statistics\n");
211        show_stats();
212        show_client_stats();
213        for (i = 0; i < num_cores; i++) {
214            #if !START_ON_ALL_CORES
215            if (i == my_core_id) {
216                continue;
217            }
218            #endif
219            for (j = 0; j < client_threads; j++) {
220                if (client_calls[i][j] != iteration_count) {
221                    failed = true;
222                    goto out;
223                }
224            }
225        }
226out:
227        if (failed)
228            debug_printf("Test FAILED\n");
229        else
230            debug_printf("Test PASSED\n");
231    }
232    calls++;
233    if ((calls % 10000) == 0) {
234        show_stats();
235    }
236
237    *o1 = i1 + 1;
238    *rs = 8 * j;
239    *o2 = me;
240
241    return SYS_ERR_OK;
242}
243
244static struct mt_waitset_rpc_rx_vtbl rpc_rx_vtbl = {
245    .rpc_method_call = server_rpc_method_call
246};
247
248static errval_t connect_cb(void *st, struct mt_waitset_binding *b)
249{
250    b->rpc_rx_vtbl = rpc_rx_vtbl;
251    return SYS_ERR_OK;
252}
253
254static int run_server(void * arg)
255{
256    int i = (uintptr_t)arg;
257    struct waitset *ws = get_default_waitset();
258    errval_t err;
259
260
261    debug_printf("Server dispatch loop %d\n", i);
262    threads[i] = thread_self();
263
264    for (;;) {
265        err = event_dispatch(ws);
266        if (err_is_fail(err)) {
267            DEBUG_ERR(err, "in event_dispatch");
268            break;
269        }
270    }
271    return SYS_ERR_OK;
272}
273
274static void start_server(void)
275{
276    struct waitset *ws = get_default_waitset();
277    errval_t err;
278    int i;
279
280    debug_printf("Start server\n");
281
282    err = mt_waitset_export(NULL, export_cb, connect_cb, ws,
283                            IDC_EXPORT_FLAGS_DEFAULT);
284    if (err_is_fail(err)) {
285        USER_PANIC_ERR(err, "export failed");
286    }
287    for (i = 1; i < server_threads; i++) {
288        thread_create(run_server, (void *)(uintptr_t)i);
289    }
290}
291
292int main(int argc, char *argv[])
293{
294    errval_t err;
295    char *my_name = strdup(argv[0]);
296
297    my_core_id = disp_get_core_id();
298
299    memset(server_calls, 0, sizeof(server_calls));
300    memset(client_calls, 0, sizeof(client_calls));
301
302    if (argc == 1) {
303        debug_printf("Usage: %s server_threads client_threads iteration_count\n", argv[0]);
304    } else if (argc == 4) {
305        char *xargv[] = {my_name, argv[2], argv[3], NULL};
306
307        server_threads = atoi(argv[1]);
308        client_threads = atoi(argv[2]);
309        iteration_count = atoi(argv[3]);
310
311        #if !START_ON_ALL_CORES
312        debug_printf("XXX: disabling starting on the same core\n");
313        #endif
314
315        err = spawn_program_on_all_cores(START_ON_ALL_CORES, xargv[0], xargv, NULL,
316            SPAWN_FLAGS_DEFAULT, NULL, &num_cores);
317        debug_printf("spawn program on all cores (%d)\n", num_cores);
318        assert(err_is_ok(err));
319
320        #if !START_ON_ALL_CORES
321        num_cores += 1;
322        #endif
323
324
325        start_server();
326
327        struct waitset *ws = get_default_waitset();
328
329        threads[0] = thread_self();
330        for (;;) {
331            err = event_dispatch(ws);
332            if (err_is_fail(err)) {
333                DEBUG_ERR(err, "in event_dispatch");
334                break;
335            }
336        }
337    } else {
338        client_threads = atoi(argv[1]);
339        iteration_count = atoi(argv[2]);
340
341        struct waitset *ws = get_default_waitset();
342        start_client();
343        debug_printf("Client process events\n");
344        for (;;) {
345            err = event_dispatch(ws);
346            if (err_is_fail(err)) {
347                DEBUG_ERR(err, "in event_dispatch");
348                break;
349            }
350        }
351    }
352    return EXIT_FAILURE;
353}
354