1#include <mdb/mdb_tree.h>
2#include <mdb/mdb.h>
3#include <cap_predicates.h>
4#include <barrelfish_kpi/capabilities.h>
5#include <capabilities.h>
6#include <assert.h>
7#include <stdio.h>
8#if IN_KERNEL
9#include <kernel.h>
10#include <kcb.h>
11#endif
12
// Fallback two-argument minimum, used only if no earlier header defined MIN.
// Function-like macro: the chosen argument is evaluated twice, so do not pass
// expressions with side effects.
#ifndef MIN
#define MIN(a, b) ((a)<(b)?(a):(b))
#endif
// Fallback two-argument maximum; same double-evaluation caveat as MIN.
#ifndef MAX
#define MAX(a, b) ((a)>(b)?(a):(b))
#endif

// N(cte): pointer to the mdbnode member embedded in a cte. Any previous
// definition of N is dropped first so this file's meaning wins.
#ifdef N
#undef N
#endif
#define N(cte) (&(cte)->mdbnode)
// C(cte): pointer to the cap member embedded in a cte. Any previous
// definition of C is dropped first as well.
#ifdef C
#undef C
#endif
#define C(cte) (&(cte)->cap)
28
// define panic() for user-land build
// When IN_KERNEL is not defined, panic() prints its printf-style arguments
// and then calls abort(). The do/while(0) wrapper makes the macro usable as a
// single statement.
#ifndef IN_KERNEL
#define panic(msg...) \
    do { \
        printf(msg); \
        abort(); \
    }while(0)
#endif
37
// PP switch to change behaviour if invariants fail
#ifdef MDB_FAIL_INVARIANTS
// on failure, dump mdb and terminate
// Prints the subtree rooted at the offending cte with mdb_dump(), then panics
// with the cte pointer and the name and number of the violated invariant.
__attribute__((noreturn))
static void
mdb_dump_and_fail(struct cte *cte, enum mdb_invariant failure)
{
    mdb_dump(cte, 0);
    panic("failed on cte %p with failure %s (%d)\n",
          cte, mdb_invariant_to_str(failure), failure);
}
// With MDB_FAIL_INVARIANTS a violated invariant never returns to the caller.
#define MDB_RET_INVARIANT(cte, failure) mdb_dump_and_fail(cte, failure)
#else
// Otherwise the violated invariant is returned from the enclosing function;
// note that this macro expands to a bare return statement.
#define MDB_RET_INVARIANT(cte, failure) return failure
#endif
53
// PP switch to toggle top-level checking of invariants
// CHECK_INVARIANTS(root, cte, reach): panic unless mdb_is_reachable(root, cte)
// equals reach, then check the invariants of the subtree rooted at cte.
// Expands to a no-op when MDB_CHECK_INVARIANTS is not defined.
#ifdef MDB_CHECK_INVARIANTS
#define CHECK_INVARIANTS(root, cte, reach) \
do { \
    if (mdb_is_reachable(root, cte) != reach) { \
        panic("mdb_is_reachable(%p,%p) != %d", root, cte, reach); \
    } \
    mdb_check_subtree_invariants(cte); \
} while(0)
#else
#define CHECK_INVARIANTS(root, cte, reach) ((void)0)
#endif

// PP switch to toggle recursive checking of invariants by default
// CHECK_INVARIANTS_SUB(cte) is what the MDB_TRACE_* entry/exit macros call;
// it checks the subtree rooted at cte only when MDB_RECHECK_INVARIANTS is
// defined and is a no-op otherwise.
#ifdef MDB_RECHECK_INVARIANTS
// disable toplevel invariants checks except for the assertion clause as we're
// doing them anyway in CHECK_INVARIANTS_SUB
#undef CHECK_INVARIANTS
#define CHECK_INVARIANTS(root, cte, reach) \
do { \
    if (mdb_is_reachable(root, cte) != reach) { \
        panic("mdb_is_reachable(%p,%p) != %d", root, cte, reach); \
    } \
} while(0)
#define CHECK_INVARIANTS_SUB(cte) mdb_check_subtree_invariants(cte)
#else
#define CHECK_INVARIANTS_SUB(cte) ((void)0)
#endif
82
// printf tracing and entry/exit invariant checking
//
// MDB_TRACE_ENTER(valid_cte, args_fmt, ...): run CHECK_INVARIANTS_SUB on
//   valid_cte; with MDB_TRACE also print "enter <function>(<args>)".
// MDB_TRACE_LEAVE_SUB(valid_cte): run CHECK_INVARIANTS_SUB on valid_cte and
//   return from the enclosing void function.
// MDB_TRACE_LEAVE_SUB_RET(ret_fmt, ret, valid_cte): same, but return ret from
//   the enclosing function; with MDB_TRACE, ret is printed using ret_fmt.
//
// Both LEAVE macros contain a return statement, so any code placed after them
// in the same control path is unreachable.
#ifdef MDB_TRACE
#define MDB_TRACE_ENTER(valid_cte, args_fmt, ...) do { \
    printf("enter %s(" args_fmt ")\n", __func__, __VA_ARGS__); \
    CHECK_INVARIANTS_SUB(valid_cte); \
} while (0)
#define MDB_TRACE_LEAVE_SUB(valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    printf("leave %s\n", __func__); \
    return; \
} while (0)
#define MDB_TRACE_LEAVE_SUB_RET(ret_fmt, ret, valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    printf("leave %s->" ret_fmt "\n", __func__, (ret)); \
    return (ret); \
} while (0)
#else
#define MDB_TRACE_ENTER(valid_cte, args_fmt, ...) CHECK_INVARIANTS_SUB(valid_cte)
#define MDB_TRACE_LEAVE_SUB(valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    return; \
} while (0)
#define MDB_TRACE_LEAVE_SUB_RET(ret_fmt, ret, valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    return (ret); \
} while (0)
#endif
110
// Global for test_ops_with_root.c
// Root of the tree that every operation in this file works on; NULL means
// the tree is empty. It has external linkage (not static).
struct cte *mdb_root = NULL;
#if IN_KERNEL
// Kernel builds only: the KCB handed to mdb_init(). set_root() writes the
// current root back into its mdb_root field.
struct kcb *my_kcb = NULL;
#endif
116static void set_root(struct cte *new_root)
117{
118    mdb_root = new_root;
119#if IN_KERNEL
120    my_kcb->mdb_root = (lvaddr_t) new_root;
121#endif
122}
123
124/*
125 * (re)initialization
126 */
127errval_t
128mdb_init(struct kcb *k)
129{
130    assert (k != NULL);
131#if IN_KERNEL
132#if 0
133    //XXX: write two versions of this; so we can have full sanity checks for
134    //all scenarios -SG
135    if (my_kcb) {
136        printf("MDB has non-null kcb.\n");
137        return CAPS_ERR_MDB_ALREADY_INITIALIZED;
138    }
139#endif
140    my_kcb = k;
141    if (!my_kcb->is_valid) {
142        // empty kcb, do nothing
143        return SYS_ERR_OK;
144    }
145#endif
146    // set root
147    mdb_root = (struct cte *)k->mdb_root;
148
149#if 0
150    // always check invariants here
151    int i = mdb_check_invariants();
152    if (i) {
153        printf("mdb invariant %s violated\n", mdb_invariant_to_str(i));
154        mdb_dump_all_the_things();
155        mdb_root = NULL;
156        return CAPS_ERR_MDB_INVARIANT_VIOLATION;
157    }
158#endif
159    return SYS_ERR_OK;
160}
161
162
163/*
164 * Debug printing.
165 */
166
167void
168mdb_dump_all_the_things(void)
169{
170    mdb_dump(mdb_root, 0);
171}
172
173STATIC_ASSERT(50 == ObjType_Num, "Knowledge of all cap types");
174static void print_cte(struct cte *cte, char *indent_buff)
175{
176    struct mdbnode *node = N(cte);
177    char extra[255] = { 0 };
178    struct capability *cap = C(cte);
179    switch (cap->type) {
180        case ObjType_L1CNode:
181            snprintf(extra, 255,
182                    "[allocated_bytes=%"PRIxGENSIZE",rightsmask=%"PRIu8"]",
183                    cap->u.l1cnode.allocated_bytes, cap->u.l1cnode.rightsmask);
184            break;
185        case ObjType_L2CNode:
186            snprintf(extra, 255, "[rightsmask=%"PRIu8"]", cap->u.l2cnode.rightsmask);
187            break;
188        case ObjType_EndPoint:
189            snprintf(extra, 255,
190                    "[listener=%p,epoffset=0x%08"PRIxLVADDR",epbuflen=%"PRIu32"]",
191                    cap->u.endpoint.listener,cap->u.endpoint.epoffset,cap->u.endpoint.epbuflen);
192            break;
193        case ObjType_Dispatcher:
194            snprintf(extra, 255, "[dcb=%p]", cap->u.dispatcher.dcb);
195            break;
196        case ObjType_IO:
197            snprintf(extra, 255, "[start=0x%04"PRIx16",end=0x%04"PRIx16"]",
198                    cap->u.io.start, cap->u.io.end);
199            break;
200        default:
201            break;
202    }
203    printf("%s%p{left=%p,right=%p,end=0x%08"PRIxGENPADDR",end_root=%"PRIu8","
204            "level=%"PRIu8",address=0x%08"PRIxGENPADDR",size=0x%08"PRIx64","
205            "type=%"PRIu8",remote_rels=%d%d%d,extra=%s}\n",
206            indent_buff,
207            cte, node->left, node->right, node->end, node->end_root,
208            node->level, get_address(C(cte)), get_size(C(cte)),
209            (uint8_t)C(cte)->type, node->remote_copies,
210            node->remote_ancs, node->remote_descs,extra);
211    return;
212}
213
214void
215mdb_dump(struct cte *cte, int indent)
216{
217    // Print a tree with root on the left, the smallest element at the top and the
218    // largest at the bottom.
219
220    /* make an indent buffer */
221    char indent_buff[indent+2];
222    for (int i=0; i < indent+1; i++) {
223        indent_buff[i]='\t';
224    }
225    indent_buff[indent+1] = '\0';
226
227    if (!cte) {
228        printf("NULL{}\n");
229        return;
230    }
231
232    struct mdbnode *node = N(cte);
233    assert(node);
234
235    if (node->left) {
236        if (node->left == cte) {
237            printf("%sSELF!!!!\n", indent_buff);
238        }
239        else {
240            mdb_dump(node->left, indent+1);
241        }
242    }
243
244    print_cte(cte, indent_buff);
245
246    if (node->right) {
247        if (node->right == cte) {
248            printf("%sSELF!!!!\n", indent_buff);
249        }
250        else {
251            mdb_dump(node->right, indent+1);
252        }
253    }
254}
255
256
257/*
258 * Invariant checking.
259 */
260
261static int
262mdb_check_subtree_invariants(struct cte *cte)
263{
264    if (!cte) {
265        return MDB_INVARIANT_OK;
266    }
267    if (C(cte)->type == 0) {
268        mdb_dump_all_the_things();
269    }
270    assert(C(cte)->type != 0);
271
272    int err;
273    struct mdbnode *node = N(cte);
274
275    if (node->level > 0 && !(node->left && node->right)) {
276        MDB_RET_INVARIANT(cte, MDB_INVARIANT_BOTHCHILDREN);
277    }
278    if (node->left && !(N(node->left)->level < node->level)) {
279        MDB_RET_INVARIANT(cte, MDB_INVARIANT_LEFT_LEVEL_LESS);
280    }
281    if (node->right && !(N(node->right)->level <= node->level)) {
282        MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHT_LEVEL_LEQ);
283    }
284    if (node->right && N(node->right)->right &&
285        !(N(N(node->right)->right)->level < node->level))
286    {
287        MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHTRIGHT_LEVEL_LESS);
288    }
289    if (node->right && N(node->right)->left &&
290        !(N(N(node->right)->left)->level < node->level))
291    {
292        MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHTLEFT_LEVEL_LESS);
293    }
294
295    // build expected end root for current node
296    mdb_root_t expected_end_root = get_type_root(C(cte)->type);
297    if (node->left) {
298        expected_end_root = MAX(expected_end_root, N(node->left)->end_root);
299    }
300    if (node->right) {
301        expected_end_root = MAX(expected_end_root, N(node->right)->end_root);
302    }
303    if (node->end_root != expected_end_root) {
304        MDB_RET_INVARIANT(cte, MDB_INVARIANT_END_IS_MAX);
305    }
306
307    // build expected end for current node. this is complex because the root
308    // acts as an address prefix, so only ends where the corresponding root is
309    // the maximum may be considered.
310    genpaddr_t expected_end = 0;
311    if (get_type_root(C(cte)->type) == node->end_root) {
312        // only consider current cte's end if its root is node->end_root
313        expected_end = get_address(C(cte))+get_size(C(cte));
314    }
315    if (node->left && N(node->left)->end_root == node->end_root) {
316        // only consider left child end if its end_root is node->end_root
317        expected_end = MAX(expected_end, N(node->left)->end);
318    }
319    if (node->right && N(node->right)->end_root == node->end_root) {
320        // only consider right child end if its end_root is node->end_root
321        expected_end = MAX(expected_end, N(node->right)->end);
322    }
323    if (node->end != expected_end) {
324        MDB_RET_INVARIANT(cte, MDB_INVARIANT_END_IS_MAX);
325    }
326
327    if (node->left) {
328        assert(node->left != cte);
329        if (compare_caps(C(node->left), C(cte), true) >= 0) {
330            MDB_RET_INVARIANT(cte, MDB_INVARIANT_LEFT_SMALLER);
331        }
332        err = mdb_check_subtree_invariants(node->left);
333        if (err) {
334            return err;
335        }
336    }
337
338    if (node->right) {
339        assert(node->right != cte);
340        if (compare_caps(C(node->right), C(cte), true) <= 0) {
341            MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHT_GREATER);
342        }
343        err = mdb_check_subtree_invariants(node->right);
344        if (err) {
345            return err;
346        }
347    }
348
349    return MDB_INVARIANT_OK;
350}
351
352int
353mdb_check_invariants(void)
354{
355    int res = mdb_check_subtree_invariants(mdb_root);
356    if (res != 0) {
357        printf("mdb_check_invariants() -> %d\n", res);
358    }
359    return res;
360}
361
362static bool
363mdb_is_reachable(struct cte *root, struct cte *cte)
364{
365    if (!root) {
366        return false;
367    }
368    if (root == cte) {
369        return true;
370    }
371    if (N(root)->left) {
372        if (mdb_is_reachable(N(root)->left, cte)) {
373            return true;
374        }
375    }
376    if (N(root)->right) {
377        if (mdb_is_reachable(N(root)->right, cte)) {
378            return true;
379        }
380    }
381    return false;
382}
383
384/*
385 * General internal helpers.
386 */
387
388static void
389mdb_update_end(struct cte *cte)
390{
391    if (!cte) {
392        return;
393    }
394    struct mdbnode *node = N(cte);
395
396    // build end root for current node
397    mdb_root_t end_root = get_type_root(C(cte)->type);
398    if (node->left) {
399        end_root = MAX(end_root, N(node->left)->end_root);
400    }
401    if (node->right) {
402        end_root = MAX(end_root, N(node->right)->end_root);
403    }
404    node->end_root = end_root;
405
406    // build end address for current node. this is complex because the root
407    // acts as an address prefix, so only ends where the corresponding root is
408    // the maximum may be considered.
409    genpaddr_t end = 0;
410    if (get_type_root(C(cte)->type) == node->end_root) {
411        // only consider current cte's end if its root is node->end_root
412        end = get_address(C(cte))+get_size(C(cte));
413    }
414    if (node->left && N(node->left)->end_root == node->end_root) {
415        // only consider left child end if its end_root is node->end_root
416        end = MAX(end, N(node->left)->end);
417    }
418    if (node->right && N(node->right)->end_root == node->end_root) {
419        // only consider right child end if its end_root is node->end_root
420        end = MAX(end, N(node->right)->end);
421    }
422    node->end = end;
423}
424
425static struct cte*
426mdb_skew(struct cte *node)
427{
428    /* transform invalid state
429     *
430     *               |
431     *      |L|<---|T|
432     *     /   \      \
433     *   |A|   |B|    |R|
434     *
435     * to valid equivalent state
436     *
437     *       |
438     *      |L|--->|T|
439     *     /      /   \
440     *   |A|    |B|   |R|
441     *
442     */
443    if (!node || !N(node)->left) {
444        return node;
445    }
446    else if (N(node)->level == N(N(node)->left)->level) {
447        struct cte *left = N(node)->left;
448        N(node)->left = N(left)->right;
449        N(left)->right = node;
450        mdb_update_end(node);
451        mdb_update_end(left);
452
453        // need to update mdb_root
454        if (mdb_root == node) {
455            set_root(left);
456        }
457        return left;
458    }
459    else {
460        return node;
461    }
462}
463
464static struct cte*
465mdb_split(struct cte *node)
466{
467    /* transform invalid state
468     *
469     *       |
470     *      |T|--->|R|-->|X|
471     *     /      /
472     *   |A|    |B|
473     *
474     * to valid equivalent state
475     *
476     *             |
477     *            |R|
478     *           /   \
479     *         |T|   |X|
480     *        /   \
481     *      |A|   |B|
482     *
483     */
484    if (!node || !N(node)->right || !N(N(node)->right)->right) {
485        return node;
486    }
487    else if (N(node)->level == N(N(N(node)->right)->right)->level) {
488        struct cte *right = N(node)->right;
489        N(node)->right = N(right)->left;
490        N(right)->left = node;
491        N(right)->level += 1;
492        mdb_update_end(node);
493        mdb_update_end(right);
494
495        // need to update mdb_root
496        if (mdb_root == node) {
497            set_root(right);
498        }
499        return right;
500    }
501    else {
502        return node;
503    }
504}
505
506static void
507mdb_decrease_level(struct cte *node)
508{
509    assert(node);
510
511    mdb_level_t expected;
512    if (!N(node)->left || !N(node)->right) {
513        expected = 0;
514    }
515    else {
516        expected = MIN(N(N(node)->left)->level, N(N(node)->right)->level) + 1;
517    }
518
519    if (expected < N(node)->level) {
520        N(node)->level = expected;
521        if (N(node)->right && expected < N(N(node)->right)->level) {
522            N(N(node)->right)->level = expected;
523        }
524    }
525}
526
527static struct cte*
528mdb_rebalance(struct cte *node)
529{
530    assert(node);
531    mdb_update_end(node);
532    mdb_decrease_level(node);
533    node = mdb_skew(node);
534    N(node)->right = mdb_skew(N(node)->right);
535    if (N(node)->right) {
536        N(N(node)->right)->right = mdb_skew(N(N(node)->right)->right);
537    }
538    node = mdb_split(node);
539    N(node)->right = mdb_split(N(node)->right);
540    return node;
541}
542
543#ifndef NDEBUG
544static bool
545mdb_is_child(struct cte *child, struct cte *parent)
546{
547    if (!parent) {
548        return mdb_root == child;
549    }
550    else {
551        return N(parent)->left == child || N(parent)->right == child;
552    }
553}
554#else
555#define mdb_is_child(a, b) 0
556#endif
557
558static bool
559mdb_is_inside(genpaddr_t outer_begin, genpaddr_t outer_end,
560              genpaddr_t inner_begin, genpaddr_t inner_end)
561{
562    assert(outer_begin <= outer_end);
563    assert(inner_begin <= inner_end);
564    return
565        (inner_begin >= outer_begin && inner_end < outer_end) ||
566        (inner_begin > outer_begin && inner_end <= outer_end);
567}
568
569/*
570 * Operations and operation-specific helpers.
571 */
572
573static errval_t
574mdb_sub_insert(struct cte *new_node, struct cte **current)
575{
576    errval_t err;
577    assert(new_node);
578    assert(current);
579    MDB_TRACE_ENTER(*current, "%p, %p (*%p)", new_node, *current, current);
580
581    struct cte *current_ = *current;
582
583    if (!current_) {
584        // we've reached an empty leaf, insert here
585        *current = new_node;
586        mdb_update_end(new_node);
587        return SYS_ERR_OK;
588    }
589
590    int compare = compare_caps(C(new_node), C(current_), true);
591    if (compare < 0) {
592        // new_node < current
593        err = mdb_sub_insert(new_node, &N(current_)->left);
594        if (err_is_fail(err)) {
595            return err;
596        }
597    }
598    else if (compare > 0) {
599        // new_node > current
600        err = mdb_sub_insert(new_node, &N(current_)->right);
601        if (err_is_fail(err)) {
602            return err;
603        }
604    }
605    else {
606        return CAPS_ERR_MDB_DUPLICATE_ENTRY;
607    }
608
609    mdb_update_end(current_);
610    current_ = mdb_skew(current_);
611    current_ = mdb_split(current_);
612    *current = current_;
613
614    err = SYS_ERR_OK;
615    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
616}
617
618errval_t
619mdb_insert(struct cte *new_node)
620{
621    MDB_TRACE_ENTER(mdb_root, "%p", new_node);
622#ifdef IN_KERNEL
623#ifdef MDB_TRACE_NO_RECURSIVE
624    char prefix[50];
625    snprintf(prefix, 50, "mdb_insert.%d: ", my_core_id);
626    print_cte(new_node, prefix);
627#endif
628#endif
629    errval_t ret = mdb_sub_insert(new_node, &mdb_root);
630    CHECK_INVARIANTS(mdb_root, new_node, true);
631    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, ret, mdb_root);
632}
633
634static void
635mdb_exchange_child(struct cte *first, struct cte *first_parent,
636                   struct cte *second)
637{
638    assert(mdb_is_child(first, first_parent));
639
640    if (!first_parent) {
641        set_root(second);
642    }
643    else if (N(first_parent)->left == first) {
644        N(first_parent)->left = second;
645    }
646    else if (N(first_parent)->right == first) {
647        N(first_parent)->right = second;
648    }
649    else {
650        assert(!"first is not child of first_parent");
651    }
652}
653
/**
 * Swap the positions of two nodes in the tree.
 *
 * Each node is linked into the other's parent slot, then their left
 * children, right children and levels are swapped, and finally both cached
 * ends are recomputed.
 *
 * NOTE: the statement order matters. mdb_exchange_remove() calls this with
 * second_parent == first, i.e. with directly adjacent nodes, so the parent
 * slots are rewritten before the child pointers are swapped.
 *
 * \param first         First node.
 * \param first_parent  Parent of first, or NULL if first is the root.
 * \param second        Second node.
 * \param second_parent Parent of second, or NULL if second is the root.
 */
static void
mdb_exchange_nodes(struct cte *first, struct cte *first_parent,
                   struct cte *second, struct cte *second_parent)
{
    struct cte *tmp_node;
    mdb_level_t tmp_level;

    // relink the parents' child slots (or the root) first
    mdb_exchange_child(first, first_parent, second);
    mdb_exchange_child(second, second_parent, first);

    // swap left children
    tmp_node = N(first)->left;
    N(first)->left = N(second)->left;
    N(second)->left = tmp_node;

    // swap right children
    tmp_node = N(first)->right;
    N(first)->right = N(second)->right;
    N(second)->right = tmp_node;

    // levels belong to the position in the tree, so they are swapped too
    tmp_level = N(first)->level;
    N(first)->level = N(second)->level;
    N(second)->level = tmp_level;

    // recompute cached ends from the new children
    mdb_update_end(first);
    mdb_update_end(second);

    assert(mdb_is_reachable(mdb_root, first));
    assert(mdb_is_reachable(mdb_root, second));
}
682
/**
 * Helper for mdb_subtree_remove(): find the in-order neighbour of target
 * inside one of target's subtrees, swap it with target and unlink target.
 *
 * The function recurses in one direction until it reaches the neighbour,
 * exchanges it with target (so target ends up at the bottom and can be
 * dropped), and rebalances every node it passes on the way back up.
 *
 * \param target        Node being removed; must be reachable from mdb_root.
 * \param target_parent Parent of target, or NULL if target is the root.
 * \param current       Slot holding the node currently visited; updated when
 *                      that position changes (set to NULL if the visited
 *                      node was the neighbour and has been moved away).
 * \param parent        Node that owns *current.
 * \param dir           Search direction: > 0 walks right (the first call
 *                      starts at target's left child), < 0 walks left (the
 *                      first call starts at target's right child). Must not
 *                      be 0.
 * \param ret_target    Out: must point to NULL on entry; receives the node
 *                      that has taken target's former position.
 */
static void
mdb_exchange_remove(struct cte *target, struct cte *target_parent,
                    struct cte **current, struct cte *parent,
                    int dir, struct cte **ret_target)
{
    assert(current);
    MDB_TRACE_ENTER(*current, "%p, %p, %p (*%p), %p, %d", target, target_parent, *current, current, parent, dir);
    assert(target);
    assert(*current);
    assert(parent);
    assert(C(target)->type != 0);
    assert(C(*current)->type != 0);
    assert(C(parent)->type != 0);
    assert(ret_target);
    assert(!*ret_target);
    assert(dir != 0);
    assert(compare_caps(C(target), C(*current), true) != 0);
    assert(mdb_is_child(target, target_parent));
    assert(mdb_is_child(*current, parent));
    assert(mdb_is_reachable(mdb_root, target));

    struct cte *current_ = *current;

    if (dir > 0) {
        // walking right: only the first step (directly below target) comes
        // from a left link, every later step from a right link
        if (parent == target) {
            assert(N(parent)->left == current_);
        }
        else {
            assert(N(parent)->right == current_);
        }

        if (N(current_)->right) {
            mdb_exchange_remove(target, target_parent, &N(current_)->right,
                                current_, dir, ret_target);
        }
    }
    else if (dir < 0) {
        // walking left: mirror image of the case above
        if (parent == target) {
            assert(N(parent)->right == current_);
        }
        else {
            assert(N(parent)->left == current_);
        }

        if (N(current_)->left) {
            mdb_exchange_remove(target, target_parent, &N(current_)->left,
                                current_, dir, ret_target);
        }
        else if (N(current_)->right) {
            assert(N(current_)->level == 0);
            // right is non-null, left null -> current is level 0 node with
            // horizontal right link, and is also the successor of the target.
            // in this case, exchange current and current right, then current
            // (at its new position) and the target.
            struct cte *new_current = N(current_)->right;
            mdb_exchange_nodes(current_, parent, N(current_)->right, current_);
            mdb_exchange_nodes(target, target_parent, current_, new_current);
            // "current" is now located where the target was, further up in the
            // tree. "new_current" is the node where current was. "target" is
            // where current->right was, and is a leaf, so can be dropped.
            assert(N(new_current)->right == target);
            N(new_current)->right = NULL;
            *ret_target = current_;
            *current = new_current;
            assert(!mdb_is_reachable(mdb_root, target));
            // returns from this function
            MDB_TRACE_LEAVE_SUB(NULL);
        }
    }

    if (*ret_target) {
        assert(!mdb_is_reachable(mdb_root, target));
        // implies we recursed further down to find a leaf. need to rebalance.
        current_ = mdb_rebalance(current_);
        *current = current_;
        MDB_TRACE_LEAVE_SUB(current_);
    }
    else {
        //printf("found leaf %p\n", current_);
        // found successor/predecessor leaf, exchange with target
        assert(!N(current_)->right && !N(current_)->left);
        mdb_exchange_nodes(target, target_parent, current_, parent);

        // "current" is now where target was, so set as ret_target
        *ret_target = current_;
        // target would be the new current, but we're removing it, so set
        // current to null. This also sets parent's corresponding child to
        // null by recursion.
        *current = NULL;
        MDB_TRACE_LEAVE_SUB(NULL);
    }
}
774
775static errval_t
776mdb_subtree_remove(struct cte *target, struct cte **current, struct cte *parent)
777{
778    assert(current);
779    MDB_TRACE_ENTER(*current, "%p, %p (*%p), %p", target, *current, current, parent);
780
781    errval_t err;
782    struct cte *current_ = *current;
783    if (!current_) {
784        err = CAPS_ERR_MDB_ENTRY_NOTFOUND;
785        MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
786    }
787
788    int compare = compare_caps(C(target), C(current_), true);
789    if (compare > 0) {
790        err = mdb_subtree_remove(target, &N(current_)->right, current_);
791        if (err != SYS_ERR_OK) {
792            MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
793            return err;
794        }
795    }
796    else if (compare < 0) {
797        err = mdb_subtree_remove(target, &N(current_)->left, current_);
798        if (err != SYS_ERR_OK) {
799            MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
800        }
801    }
802    else {
803        assert(current_ == target);
804        if (!N(current_)->left && !N(current_)->right) {
805            // target is leaf, just remove
806            *current = NULL;
807            err = SYS_ERR_OK;
808            MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, NULL);
809        }
810        else if (!N(current_)->left) {
811            // move to right child then go left (dir=-1)
812            // curr, new_right = xchg_rm(elem, parent, current.right, current, -1)
813            struct cte *new_current = NULL;
814            struct cte *new_right = N(current_)->right;
815            mdb_exchange_remove(target, parent, &new_right, current_, -1,
816                                &new_current);
817            assert(new_current);
818            current_ = new_current;
819            N(current_)->right = new_right;
820            assert(!mdb_is_reachable(mdb_root, target));
821        }
822        else {
823            // move to left child then go right (dir=1)
824            // curr, new_left = xchg_rm(elem, parent, current.left, current, 1)
825            struct cte *new_current = NULL;
826            struct cte *new_left = N(current_)->left;
827            mdb_exchange_remove(target, parent, &new_left, current_, 1,
828                                &new_current);
829            assert(new_current);
830            current_ = new_current;
831            N(current_)->left = new_left;
832            assert(!mdb_is_reachable(mdb_root, target));
833        }
834    }
835
836    // rebalance after remove from subtree
837    current_ = mdb_rebalance(current_);
838    *current = current_;
839
840    assert(C(target)->type != 0);
841    assert(!*current || C(*current)->type != 0);
842
843    err = SYS_ERR_OK;
844    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
845}
846
847errval_t
848mdb_remove(struct cte *target)
849{
850    MDB_TRACE_ENTER(mdb_root, "%p", target);
851    CHECK_INVARIANTS(mdb_root, target, true);
852#ifdef IN_KERNEL
853#ifdef MDB_TRACE_NO_RECURSIVE
854    char prefix[50];
855    snprintf(prefix, 50, "mdb_remove.%d: ", my_core_id);
856    print_cte(target, prefix);
857#endif
858#endif
859    errval_t err = mdb_subtree_remove(target, &mdb_root, NULL);
860    CHECK_INVARIANTS(mdb_root, target, false);
861    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, mdb_root);
862}
863
864/*
865 * Queries on the ordering.
866 */
867
868static struct cte*
869mdb_sub_find_equal(struct capability *cap, struct cte *current)
870{
871    if (!current) {
872        return NULL;
873    }
874    int compare = compare_caps(cap, C(current), false);
875    if (compare < 0) {
876        // current is gt key, look for smaller node
877        return mdb_sub_find_equal(cap, N(current)->left);
878    }
879    else if (compare > 0) {
880        // current is lt key, attempt to find bigger current
881        return mdb_sub_find_equal(cap, N(current)->right);
882    }
883    else {
884        return current;
885    }
886}
887
888struct cte*
889mdb_find_equal(struct capability *cap)
890{
891    return mdb_sub_find_equal(cap, mdb_root);
892}
893
894static struct cte*
895mdb_sub_find_less(struct capability *cap, struct cte *current, bool equal_ok,
896                  bool tiebreak)
897{
898    if (!current) {
899        return NULL;
900    }
901    int compare = compare_caps(cap, C(current), tiebreak);
902    if (compare < 0) {
903        // current is gt key, look for smaller node
904        return mdb_sub_find_less(cap, N(current)->left, equal_ok, tiebreak);
905    }
906    else if (compare > 0) {
907        // current is lt key, attempt to find bigger current
908        struct cte *res = mdb_sub_find_less(cap, N(current)->right, equal_ok,
909                                            tiebreak);
910        if (res) {
911            return res;
912        }
913        // bigger child exceeded key
914        return current;
915    }
916    else {
917        // found equal element
918        if (equal_ok) {
919            return current;
920        }
921        else {
922            // look for smaller node
923            return mdb_sub_find_less(cap, N(current)->left, equal_ok,
924                                     tiebreak);
925        }
926    }
927}
928
929struct cte*
930mdb_find_less(struct capability *cap, bool equal_ok)
931{
932    return mdb_sub_find_less(cap, mdb_root, equal_ok, false);
933}
934
935static struct cte*
936mdb_sub_find_greater(struct capability *cap, struct cte *current,
937                     bool equal_ok, bool tiebreak)
938{
939    if (!current) {
940        return NULL;
941    }
942    int compare = compare_caps(cap, C(current), tiebreak);
943    if (compare < 0) {
944        // current is gt key, attempt to find smaller node
945        struct cte *res = mdb_sub_find_greater(cap, N(current)->left, equal_ok,
946                                               tiebreak);
947        if (res) {
948            return res;
949        }
950        // smaller was lte key
951        return current;
952    }
953    else if (compare > 0) {
954        // current is lte key, look for greater node
955        return mdb_sub_find_greater(cap, N(current)->right, equal_ok,
956                                    tiebreak);
957    }
958    else {
959        // found equal element
960        if (equal_ok) {
961            return current;
962        }
963        else {
964            // look for greater node
965            return mdb_sub_find_greater(cap, N(current)->right, equal_ok,
966                                        tiebreak);
967        }
968    }
969}
970
971struct cte*
972mdb_find_greater(struct capability *cap, bool equal_ok)
973{
974    return mdb_sub_find_greater(cap, mdb_root, equal_ok, false);
975}
976
977struct cte*
978mdb_predecessor(struct cte *current)
979{
980    struct mdbnode *node = N(current);
981    if (node->left) {
982        // if possible, look just at children
983        current = node->left;
984        while ((node = N(current))->right) {
985            current = node->right;
986        }
987        return current;
988    }
989    // XXX: in lieu of a parent pointer that can be used to traverse upwards,
990    // we have to perform a search through the tree from the root. This makes
991    // "predecessor" into a O(log(n)) operation, instead of the expected O(1).
992    return mdb_sub_find_less(C(current), mdb_root, false, true);
993}
994
995struct cte*
996mdb_successor(struct cte *current)
997{
998    struct mdbnode *node = N(current);
999    if (node->right) {
1000        // if possible, look just at children
1001        current = node->right;
1002        while ((node = N(current))->left) {
1003            current = node->left;
1004        }
1005        return current;
1006    }
1007    // XXX: in lieu of a parent pointer that can be used to traverse upwards,
1008    // we have perform a search through the tree from the root. This makes
1009    // "successor" into a O(log(n)) operation, instead of the expected O(1).
1010    return mdb_sub_find_greater(C(current), mdb_root, false, true);
1011}
1012
1013/*
1014 * The range query.
1015 */
1016
1017static struct cte*
1018mdb_choose_surrounding(genpaddr_t address, size_t size, struct cte *first,
1019                       struct cte *second)
1020{
1021    assert(first);
1022    assert(second);
1023    assert(get_type_root(C(first)->type) == get_type_root(C(second)->type));
1024#ifndef NDEBUG
1025    genpaddr_t beg = address, end = address + size;
1026    genpaddr_t fst_beg = get_address(C(first));
1027    genpaddr_t snd_beg = get_address(C(second));
1028    genpaddr_t fst_end = fst_beg + get_size(C(first));
1029    genpaddr_t snd_end = snd_beg + get_size(C(second));
1030    assert(fst_beg <= beg && fst_end >= end);
1031    assert(snd_beg <= beg && snd_end >= end);
1032#endif
1033
1034    if (compare_caps(C(first), C(second), true) >= 0) {
1035        return first;
1036    }
1037    else {
1038        return second;
1039    }
1040}
1041
1042static struct cte*
1043mdb_choose_inner(genpaddr_t address, size_t size, struct cte *first,
1044                 struct cte *second)
1045{
1046    assert(first);
1047    assert(second);
1048    assert(get_type_root(C(first)->type) == get_type_root(C(second)->type));
1049#ifndef NDEBUG
1050    genpaddr_t end = address + size;
1051    genpaddr_t fst_beg = get_address(C(first));
1052    genpaddr_t snd_beg = get_address(C(second));
1053    genpaddr_t fst_end = fst_beg + get_size(C(first));
1054    genpaddr_t snd_end = snd_beg + get_size(C(second));
1055    assert(mdb_is_inside(address, end, fst_beg, fst_end));
1056    assert(mdb_is_inside(address, end, snd_beg, snd_end));
1057#endif
1058
1059    if (compare_caps(C(first), C(second), true) <= 0) {
1060        return first;
1061    }
1062    else {
1063        return second;
1064    }
1065}
1066
1067static struct cte*
1068mdb_choose_partial(genpaddr_t address, size_t size, struct cte *first,
1069                   struct cte *second)
1070{
1071    assert(first);
1072    assert(second);
1073    assert(get_type_root(C(first)->type) == get_type_root(C(second)->type));
1074    genpaddr_t beg = address;
1075    genpaddr_t fst_beg = get_address(C(first));
1076    genpaddr_t snd_beg = get_address(C(second));
1077#ifndef NDEBUG
1078    genpaddr_t end = address + size;
1079    genpaddr_t fst_end = fst_beg + get_size(C(first));
1080    genpaddr_t snd_end = snd_beg + get_size(C(second));
1081    assert(fst_beg < end);
1082    assert(snd_beg < end);
1083    assert(fst_end > beg);
1084    assert(snd_end > beg);
1085    assert(fst_beg != beg);
1086    assert(snd_beg != beg);
1087    assert(fst_end != end);
1088    assert(snd_end != end);
1089    assert((fst_beg < beg) == (fst_end < end));
1090    assert((snd_beg < beg) == (snd_end < end));
1091#endif
1092
1093    if (fst_beg < beg && snd_beg > beg) {
1094        return first;
1095    }
1096    else if (snd_beg < beg && fst_beg > beg) {
1097        return second;
1098    }
1099    else {
1100        if (compare_caps(C(first), C(second), true) >= 0) {
1101            return first;
1102        }
1103        else {
1104            return second;
1105        }
1106    }
1107}
1108
1109static int
1110mdb_sub_find_range(mdb_root_t root, genpaddr_t address, size_t size,
1111                   int max_precision, struct cte *current,
1112                   /*out*/ struct cte **ret_node);
1113
1114static void
1115mdb_sub_find_range_merge(mdb_root_t root, genpaddr_t address, size_t size,
1116                         int max_precision, struct cte *sub,
1117                         /*inout*/ int *ret, /*inout*/ struct cte **result)
1118{
1119    assert(sub);
1120    assert(ret);
1121    assert(result);
1122    assert(max_precision >= 0);
1123    assert(*ret <= max_precision);
1124
1125    struct cte *sub_result = NULL;
1126    int sub_ret = mdb_sub_find_range(root, address, size, max_precision, sub,
1127                                     &sub_result);
1128    if (sub_ret > max_precision) {
1129        *result = sub_result;
1130        *ret = sub_ret;
1131    }
1132    else if (sub_ret > *ret) {
1133        *result = sub_result;
1134        *ret = sub_ret;
1135    }
1136    else if (sub_ret == *ret) {
1137        switch (sub_ret) {
1138        case MDB_RANGE_NOT_FOUND:
1139            break;
1140        case MDB_RANGE_FOUND_SURROUNDING:
1141            *result = mdb_choose_surrounding(address, size, *result, sub_result);
1142            break;
1143        case MDB_RANGE_FOUND_INNER:
1144            *result = mdb_choose_inner(address, size, *result, sub_result);
1145            break;
1146        case MDB_RANGE_FOUND_PARTIAL:
1147            *result = mdb_choose_partial(address, size, *result, sub_result);
1148            break;
1149        default:
1150            assert(!"Unhandled enum value for mdb_find_range result");
1151            break;
1152        }
1153    }
1154    // else ret > sub_ret, keep ret & result as is
1155}
1156
1157static int
1158mdb_sub_find_range(mdb_root_t root, genpaddr_t address, size_t size,
1159                   int max_precision, struct cte *current,
1160                   /*out*/ struct cte **ret_node)
1161{
1162    assert(max_precision >= 0);
1163    assert(ret_node);
1164
1165    if (!current) {
1166        *ret_node = NULL;
1167        return MDB_RANGE_NOT_FOUND;
1168    }
1169
1170    if (N(current)->end_root < root) {
1171        *ret_node = NULL;
1172        return MDB_RANGE_NOT_FOUND;
1173    }
1174    if (N(current)->end_root == root && N(current)->end <= address) {
1175        *ret_node = NULL;
1176        return MDB_RANGE_NOT_FOUND;
1177    }
1178
1179    mdb_root_t current_root = get_type_root(C(current)->type);
1180
1181    struct cte *result = NULL;
1182    int ret = MDB_RANGE_NOT_FOUND;
1183
1184    genpaddr_t current_address = get_address(C(current));
1185    genpaddr_t current_end = current_address + get_size(C(current));
1186    genpaddr_t search_end = address + size;
1187
1188    if (current_root == root) {
1189
1190        if (ret < MDB_RANGE_FOUND_PARTIAL &&
1191            current_address > address &&
1192            current_address < search_end &&
1193            current_end > search_end)
1194        {
1195            result = current;
1196            ret = MDB_RANGE_FOUND_PARTIAL;
1197        }
1198        if (ret < MDB_RANGE_FOUND_PARTIAL &&
1199            current_end > address &&
1200            current_end < search_end &&
1201            current_address < address)
1202        {
1203            result = current;
1204            ret = MDB_RANGE_FOUND_PARTIAL;
1205        }
1206        if (ret < MDB_RANGE_FOUND_INNER &&
1207            mdb_is_inside(address, search_end, current_address, current_end))
1208        {
1209            result = current;
1210            ret = MDB_RANGE_FOUND_INNER;
1211        }
1212        if (ret < MDB_RANGE_FOUND_SURROUNDING &&
1213            current_address <= address &&
1214            // exclude 0-length match with curaddr==addr
1215            current_address < search_end &&
1216            current_end >= search_end &&
1217            // exclude 0-length match with currend==addr
1218            current_end > address)
1219        {
1220            result = current;
1221            ret = MDB_RANGE_FOUND_SURROUNDING;
1222        }
1223        if (ret > max_precision) {
1224            *ret_node = result;
1225            return ret;
1226        }
1227    }
1228
1229    if (N(current)->left) {
1230        mdb_sub_find_range_merge(root, address, size, max_precision,
1231                                 N(current)->left, /*inout*/&ret,
1232                                 /*inout*/&result);
1233        if (ret > max_precision) {
1234            *ret_node = result;
1235            return ret;
1236        }
1237    }
1238
1239    if (N(current)->right && root >= current_root &&
1240        (search_end > current_address || (search_end == current_address && size == 0))) {
1241        mdb_sub_find_range_merge(root, address, size, max_precision,
1242                                 N(current)->right, /*inout*/&ret,
1243                                 /*inout*/&result);
1244        if (ret > max_precision) {
1245            *ret_node = result;
1246            return ret;
1247        }
1248    }
1249
1250    *ret_node = result;
1251    return ret;
1252
1253}
1254
1255errval_t
1256mdb_find_range(mdb_root_t root, genpaddr_t address, gensize_t size,
1257               int max_result, /*out*/ struct cte **ret_node,
1258               /*out*/ int *result)
1259{
1260    if (max_result < MDB_RANGE_NOT_FOUND ||
1261        max_result > MDB_RANGE_FOUND_PARTIAL)
1262    {
1263        return CAPS_ERR_INVALID_ARGS;
1264    }
1265    if (max_result > MDB_RANGE_NOT_FOUND && !ret_node) {
1266        return CAPS_ERR_INVALID_ARGS;
1267    }
1268    if (!result) {
1269        return CAPS_ERR_INVALID_ARGS;
1270    }
1271
1272    struct cte *alt_ret_node;
1273    if (!ret_node) {
1274        ret_node = &alt_ret_node;
1275    }
1276
1277    *result = mdb_sub_find_range(root, address, size, max_result, mdb_root, ret_node);
1278    return SYS_ERR_OK;
1279}
1280
1281errval_t
1282mdb_find_cap_for_address(genpaddr_t address, struct cte **ret_node)
1283{
1284    int result;
1285    errval_t err;
1286    // query for size 1 to get the smallest cap that includes the byte at the
1287    // given address
1288    err = mdb_find_range(get_type_root(ObjType_RAM), address,
1289                         1, MDB_RANGE_FOUND_SURROUNDING, ret_node, &result);
1290    if (err_is_fail(err)) {
1291        return err;
1292    }
1293    if (result != MDB_RANGE_FOUND_SURROUNDING) {
1294        return SYS_ERR_CAP_NOT_FOUND;
1295    }
1296    return SYS_ERR_OK;
1297}
1298
1299bool mdb_reachable(struct cte *cte)
1300{
1301    return mdb_is_reachable(mdb_root, cte);
1302}
1303
1304errval_t
1305mdb_traverse(enum mdb_tree_traversal_order order, mdb_tree_traversal_fn cb, void *data)
1306{
1307    return mdb_traverse_subtree(mdb_root, order, cb, data);
1308}
1309
1310errval_t
1311mdb_traverse_subtree(struct cte *cte, enum mdb_tree_traversal_order order,
1312        mdb_tree_traversal_fn cb, void *data)
1313{
1314    struct mdbnode *node = N(cte);
1315    assert(node);
1316
1317    struct cte *first, *second;
1318    if (order == MDB_TRAVERSAL_ORDER_ASCENDING) {
1319        first = node->left;
1320        second = node->right;
1321    } else {
1322        first = node->right;
1323        second = node->left;
1324    }
1325
1326    errval_t err;
1327
1328    if (first) {
1329        err = mdb_traverse_subtree(first, order, cb, data);
1330        if (err_is_fail(err)) {
1331            return err;
1332        }
1333    }
1334
1335    err = cb(cte, data);
1336
1337    if (err_is_fail(err)) {
1338        return err;
1339    }
1340
1341    if (second) {
1342        err = mdb_traverse_subtree(second, order, cb, data);
1343        if (err_is_fail(err)) {
1344            return err;
1345        }
1346    }
1347    return SYS_ERR_OK;
1348}
1349