1/*
2 * Copyright 2016, Data61
3 * Commonwealth Scientific and Industrial Research Organisation (CSIRO)
4 * ABN 41 687 119 230.
5 *
6 * This software may be distributed and modified according to the terms of
7 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
8 * See "LICENSE_BSD2.txt" for details.
9 *
10 * @TAG(D61_BSD)
11 */
12
#include <stdint.h>

#include <refos-rpc/rpc.h>
#include <refos/refos.h>
#include <refos/vmlayout.h>
#include <refos-util/dprintf.h>
17
// Round N up / down to a multiple of S using integer division (S must be non-zero).
// Both arguments may be evaluated more than once.
#define ROUND_UP(N, S) ((((N) + (S) - 1) / (S)) * (S))
#define ROUND_DOWN(N, S) (((N) / (S)) * (S))

// Static memory pool to allocate from for IPC. This is needed because a normal malloc() might
// itself require an RPC, resulting in unexpected behaviour.
// The pool has RPC_MAX_TRACKED_OBJS slots of RPC_STATIC_MEMPOOL_OBJ_SIZE bytes each; slot i is in
// use while _rpc_static_mempool_table[i] is true (set by rpc_malloc(), cleared by rpc_free()).
// There is no locking around the pool or the table.
#define RPC_STATIC_MEMPOOL_OBJ_SIZE 4096
static char _rpc_static_mempool[RPC_MAX_TRACKED_OBJS][RPC_STATIC_MEMPOOL_OBJ_SIZE];
static bool _rpc_static_mempool_table[RPC_MAX_TRACKED_OBJS];

// Current global MR and cap index, used by the push/pop helpers (seL4_SetMR / seL4_GetMR and
// seL4_SetCap). rpc_reset_contents() rewinds them to MR 1 and cap 0; on the client side MR 0
// carries the call label (written by rpc_init()).
uint32_t _rpc_mr;
uint32_t _rpc_cp;

// Other global rpc state.
// NOTE(review): everything here except _rpc_recv_cslot has external linkage, presumably so that
// generated stubs or rpc.h can reach it -- confirm before narrowing linkage.
static seL4_CPtr _rpc_recv_cslot;  // Slot incoming caps are received into; 0 until set up.
ENDPT _rpc_dest_ep;                // Endpoint the next client call goes to (see rpc_set_dest()).
seL4_MessageInfo_t _rpc_minfo;     // Message info returned by the most recent client seL4_Call().
uint32_t _rpc_label;               // Label recorded by the most recent rpc_init().
const char* _rpc_name;             // Name recorded by the most recent rpc_init().
37
38// ------------------------------------------- RPC Helper ------------------------------------------
39
40void*
41rpc_malloc(size_t sz)
42{
43    // Minimal static buffer pool allocation.
44    // Note that we cannot malloc here, as malloc could call mmap which could call us back,
45    // resulting in a cyclic dependency.
46    assert(sz <= RPC_STATIC_MEMPOOL_OBJ_SIZE);
47    int i;
48    for (i = 0; i < RPC_MAX_TRACKED_OBJS; i++) {
49        if (!_rpc_static_mempool_table[i]) {
50            break;
51        }
52    }
53    assert(i < RPC_MAX_TRACKED_OBJS);
54    _rpc_static_mempool_table[i] = true;
55    return _rpc_static_mempool[i];
56}
57
58void
59rpc_free(void *addr)
60{
61    int i = (((char*)addr) - (&_rpc_static_mempool[0][0])) / RPC_STATIC_MEMPOOL_OBJ_SIZE;
62    assert(i >= 0 && i < RPC_MAX_TRACKED_OBJS);
63    assert(_rpc_static_mempool_table[i]);
64    _rpc_static_mempool_table[i] = false;
65}
66
67uint32_t
68rpc_marshall(uint32_t cur_mr, const char *str, uint32_t slen)
69{
70    assert(str);
71    if (slen == 0) {
72        return cur_mr;
73    }
74
75    int i;
76    for (i = 0; i < ROUND_DOWN(slen, 4); i+=4, str+=4) {
77        seL4_SetMR(cur_mr++, *(seL4_Word*) str);
78    }
79    if (i != slen) {
80        seL4_Word w = 0;
81        memcpy(&w, str, slen - i);
82        seL4_SetMR(cur_mr++, w);
83    }
84
85    return cur_mr;
86}
87
88uint32_t
89rpc_unmarshall(uint32_t cur_mr, char *str, uint32_t slen)
90{
91    assert(str);
92    if (slen == 0) return cur_mr;
93
94    int i;
95    for (i = 0; i < ROUND_DOWN(slen, 4); i+=4, str+=4) {
96        *(seL4_Word*) str = seL4_GetMR(cur_mr++);
97    }
98    if (i != slen) {
99        seL4_Word w = seL4_GetMR(cur_mr++);
100        memcpy(str, &w, slen - i);
101    }
102
103    return cur_mr;
104}
105
106void
107rpc_setup_recv(seL4_CPtr recv_cslot)
108{
109	assert(recv_cslot);
110	seL4_SetCapReceivePath(REFOS_CSPACE, recv_cslot, REFOS_CSPACE_DEPTH);
111	_rpc_recv_cslot = recv_cslot;
112}
113
114void
115rpc_setup_recv_cspace(seL4_CPtr cspace, seL4_CPtr recv_cslot, seL4_Word depth)
116{
117    assert(recv_cslot);
118    seL4_SetCapReceivePath(cspace, recv_cslot, depth);
119    _rpc_recv_cslot = recv_cslot;
120}
121
122void
123rpc_reset_contents(void *cl)
124{
125    (void) cl;
126    _rpc_mr = 1;
127    _rpc_cp = 0;
128}
129
130// ------------------------------------------- Client RPC ------------------------------------------
131
132static seL4_CPtr
133rpc_get_endpoint(int32_t label)
134{
135    if (_rpc_dest_ep) return _rpc_dest_ep;
136    assert(!"rpc_get_endpoint: unknown label.");
137    return (seL4_CPtr)0;
138}
139
140void
141rpc_init(const char* name_str, int32_t label)
142{
143    _rpc_label = label;
144    _rpc_name = name_str;
145
146	rpc_reset_contents(NULL);
147
148    if (!_rpc_recv_cslot) {
149        rpc_setup_recv(REFOS_THREAD_CAP_RECV);
150    } else if (seL4_MessageInfo_get_extraCaps(_rpc_minfo) > 0) {
151        // Flush recieving path of previous recieved caps.
152        seL4_CNode_Delete(REFOS_CSPACE, _rpc_recv_cslot, REFOS_CDEPTH);
153    }
154
155    seL4_SetMR(0, label);
156}
157
158void
159rpc_push_uint(uint32_t v)
160{
161    seL4_SetMR(_rpc_mr++, v);
162}
163
164void
165rpc_push_str(const char* v)
166{
167    uint32_t slen = strlen(v);
168    rpc_push_uint(slen);
169    _rpc_mr = rpc_marshall(_rpc_mr, v, slen);
170}
171
172void
173rpc_push_buf(void* v, size_t sz)
174{
175    if (!sz) return;
176    if (!v) sz = 0;
177    _rpc_mr = rpc_marshall(_rpc_mr, v, sz);
178}
179
/*
 * Append an array of count elements of sz bytes each.
 * Wire format: the element count in one MR, then every element marshalled separately.
 */
void
rpc_push_buf_array(void* v, size_t sz, uint32_t count)
{
    rpc_push_uint(count);
    char *base = (char*)v;
    uint32_t idx = 0;
    while (idx < count) {
        rpc_push_buf(base + idx * sz, sz);
        ++idx;
    }
}
189
190void
191rpc_push_cptr(ENDPT v)
192{
193	seL4_SetCap(_rpc_cp++, v);
194}
195
196void
197rpc_set_dest(ENDPT dest)
198{
199    _rpc_dest_ep = dest;
200}
201
202uint32_t
203rpc_pop_uint()
204{
205    return seL4_GetMR(_rpc_mr++);
206}
207
208void
209rpc_pop_str(char* v)
210{
211    // WARNING: Outputting to a C char string is never a safe thing to do.
212    uint32_t slen = rpc_pop_uint();
213    _rpc_mr = rpc_unmarshall(_rpc_mr, v, slen);
214    v[slen] = '\0';
215}
216
217void
218rpc_pop_buf(void* v, size_t sz)
219{
220    if (!sz) return;
221    assert(v);
222    _rpc_mr = rpc_unmarshall(_rpc_mr, v, sz);
223}
224
225ENDPT
226rpc_pop_cptr()
227{
228   assert(_rpc_recv_cslot);
229   if (seL4_MessageInfo_get_extraCaps(_rpc_minfo) < 1) {
230       //assert(!"RPC Failed to recieve the cap");
231       return 0;
232   }
233   return _rpc_recv_cslot;
234}
235
/*
 * Read a count-prefixed array of sz-byte elements from the reply into v.
 *
 * @param v     Destination with room for count elements.
 * @param sz    Size of one element in bytes.
 * @param count Capacity of v, in elements.
 *
 * If the sender reports more than count elements, debug builds assert. Release builds copy only
 * the first count elements so v is never overrun; the surplus elements stay unread, so any
 * later pops from this reply are unreliable.
 */
void
rpc_pop_buf_array(void* v, size_t sz, uint32_t count)
{
    uint32_t cn = rpc_pop_uint();
    assert(cn <= count);
    if (cn > count) {
        cn = count;
    }

    char *dst = (char*)v;
    for (uint32_t i = 0; i < cn; i++) {
        rpc_pop_buf(dst + (size_t) i * sz, sz);
    }
}
245
246int
247rpc_call_server()
248{
249    seL4_MessageInfo_t tag = seL4_MessageInfo_new(0, 0, _rpc_cp, _rpc_mr);
250    int ept = rpc_get_endpoint(_rpc_label);
251    _rpc_minfo = seL4_Call(ept, tag);
252    rpc_reset_contents(NULL);
253    return 0;
254}
255
/* Forget the destination endpoint chosen with rpc_set_dest(). */
void
rpc_release()
{
    rpc_set_dest(0);
}
261
262
263// ---------------------------------------------- Server RPC ---------------------------------------
264
265void
266rpc_sv_init(void *cl)
267{
268    rpc_reset_contents(cl);
269    if (!_rpc_recv_cslot) rpc_setup_recv(REFOS_THREAD_CAP_RECV);
270	if (!cl) {
271        return;
272    }
273    rpc_client_state_t* c = (rpc_client_state_t*)cl;
274    c->num_obj = 0;
275    c->skip_reply = false;
276}
277
278uint32_t
279rpc_sv_pop_uint(void *cl)
280{
281    (void)cl;
282    return seL4_GetMR(_rpc_mr++);
283}
284
285char*
286rpc_sv_pop_str(void *cl)
287{
288    uint32_t slen = rpc_sv_pop_uint(cl);
289    char *str = rpc_malloc((slen + 1) * sizeof(char));
290    assert(str);
291    _rpc_mr = rpc_unmarshall(_rpc_mr, str, slen);
292    str[slen] = '\0';
293    return str;
294}
295
296void
297rpc_sv_pop_buf(void *cl, void *v, size_t sz)
298{
299    if (!sz) return;
300    _rpc_mr = rpc_unmarshall(_rpc_mr, v, sz);
301}
302
303rpc_buffer_t
304rpc_sv_pop_buf_array(void *cl, size_t sz)
305{
306    uint32_t count = rpc_sv_pop_uint(cl);
307    char *v = rpc_malloc(count * sz);
308    for (uint32_t i = 0; i < count; i++) {
309        _rpc_mr = rpc_unmarshall(_rpc_mr, v + i * sz, sz);
310    }
311    rpc_buffer_t buffer;
312    buffer.data = v;
313    buffer.count = count;
314    return buffer;
315}
316
317ENDPT
318rpc_sv_pop_cptr(void *cl)
319{
320    rpc_client_state_t* c = (rpc_client_state_t*)cl;
321    if (_rpc_cp >= seL4_MessageInfo_get_extraCaps(c->minfo)) {
322        return 0;
323    }
324    seL4_Word unw = seL4_MessageInfo_get_capsUnwrapped(c->minfo);
325    if (unw & (1 << _rpc_cp)) {
326        return seL4_CapData_Badge_get_Badge(seL4_GetBadge(_rpc_cp++));
327    }
328    _rpc_cp++;
329    assert(_rpc_recv_cslot);
330    return _rpc_recv_cslot;
331}
332
333void
334rpc_sv_push_uint(void *cl, uint32_t v)
335{
336    (void)cl;
337    seL4_SetMR(_rpc_mr++, v);
338}
339
340void
341rpc_sv_push_buf(void *cl, void* v, size_t sz)
342{
343    if (!sz) return;
344    _rpc_mr = rpc_marshall(_rpc_mr, v, sz);
345}
346
347void
348rpc_sv_push_cptr(void *cl, ENDPT v)
349{
350    if (!v) return;
351    seL4_SetCap(_rpc_cp++, v);
352}
353
354void
355rpc_sv_push_buf_array(void *cl, rpc_buffer_t v, size_t sz)
356{
357    rpc_sv_push_uint(cl, v.count);
358    for (uint32_t i = 0; i < v.count; i++) {
359        rpc_sv_push_buf(cl, ((char*)(v.data)) + (i * sz), sz);
360    }
361}
362
363void
364rpc_sv_reply(void* cl)
365{
366    if (rpc_sv_skip_reply(cl)) return;
367    seL4_CPtr reply_endpoint = rpc_sv_get_reply_endpoint(cl);
368    seL4_MessageInfo_t reply = seL4_MessageInfo_new(0, 0, _rpc_cp, _rpc_mr);
369    if (reply_endpoint) {
370        seL4_Send(reply_endpoint, reply);
371    } else {
372        seL4_Reply(reply);
373    }
374}
375
376void
377rpc_sv_release(void *cl)
378{
379    rpc_client_state_t* c = (rpc_client_state_t*)cl;
380    (void)c;
381
382    _rpc_dest_ep = 0;
383
384    if (seL4_MessageInfo_get_extraCaps(c->minfo) > 0) {
385        // Flush recieving path of previous recieved caps.
386        seL4_CNode_Delete(REFOS_CSPACE, _rpc_recv_cslot, REFOS_CSPACE_DEPTH);
387    }
388}
389
390void
391rpc_sv_track_obj(void* cl, void* addr)
392{
393    rpc_client_state_t* c = (rpc_client_state_t*)cl;
394    assert(c->num_obj < RPC_MAX_TRACKED_OBJS - 1);
395    c->obj[c->num_obj++] = addr;
396}
397
398void
399rpc_sv_free_tracked_objs(void* cl)
400{
401    rpc_client_state_t* c = (rpc_client_state_t*)cl;
402    for (int i = 0; i < c->num_obj; i++) {
403        rpc_free(c->obj[i]);
404    }
405    c->num_obj = 0;
406}
407
408
409