1#include <mdb/mdb_tree.h>
2#include <mdb/mdb.h>
3#include <cap_predicates.h>
4#include <barrelfish_kpi/capabilities.h>
5#include <capabilities.h>
6#include <assert.h>
7#include <stdio.h>
8#if IN_KERNEL
9#include <kernel.h>
10#include <kcb.h>
11#endif
12
/* Two-argument minimum / maximum, defined only if nobody else provided them.
 * Function-like macros: an argument may be evaluated more than once. */
#if !defined(MIN)
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
#if !defined(MAX)
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif

/* Short accessors for the two members of a cte used throughout this file:
 * N(cte) yields a pointer to its mdbnode member, C(cte) a pointer to its cap
 * member. Any earlier definition of N or C is discarded first. */
#undef N
#define N(cte) (&(cte)->mdbnode)
#undef C
#define C(cte) (&(cte)->cap)
28
// Define panic() for user-land builds (IN_KERNEL not defined): print the
// printf-style message and abort the process.
// abort() is not required to flush stdio buffers, and several messages passed
// to panic() in this file do not end in a newline, so stdout is flushed
// explicitly to make sure the diagnostic is emitted before the process dies.
#ifndef IN_KERNEL
#define panic(msg...) \
    do { \
        printf(msg); \
        fflush(stdout); \
        abort(); \
    } while (0)
#endif
37
// Compile-time switch selecting what MDB_RET_INVARIANT does when an invariant
// check fails.
#if defined(MDB_FAIL_INVARIANTS)
// Fatal variant: dump the tree below the offending cte, then panic with the
// cte address plus the name and number of the violated invariant.
// Never returns.
__attribute__((noreturn))
static void
mdb_dump_and_fail(struct cte *cte, enum mdb_invariant failure)
{
    mdb_dump(cte, 0);
    panic("failed on cte %p with failure %s (%d)\n",
          cte, mdb_invariant_to_str(failure), failure);
}
#define MDB_RET_INVARIANT(cte, failure) mdb_dump_and_fail(cte, failure)
#else
// Non-fatal variant: make the enclosing function return the invariant code.
#define MDB_RET_INVARIANT(cte, failure) return failure
#endif
53
// PP switch to toggle top-level checking of invariants.
// With MDB_CHECK_INVARIANTS defined, CHECK_INVARIANTS(root, cte, reach)
// panics unless mdb_is_reachable(root, cte) equals `reach`, and then runs
// mdb_check_subtree_invariants(cte). The value returned by that call is not
// inspected by the macro. Without the switch it expands to a no-op.
#ifdef MDB_CHECK_INVARIANTS
#define CHECK_INVARIANTS(root, cte, reach) \
do { \
    if (mdb_is_reachable(root, cte) != reach) { \
        panic("mdb_is_reachable(%p,%p) != %d", root, cte, reach); \
    } \
    mdb_check_subtree_invariants(cte); \
} while(0)
#else
#define CHECK_INVARIANTS(root, cte, reach) ((void)0)
#endif

// PP switch to toggle recursive checking of invariants by default.
// With MDB_RECHECK_INVARIANTS defined, CHECK_INVARIANTS_SUB(cte) expands to a
// call of mdb_check_subtree_invariants(cte); otherwise it is a no-op.
#ifdef MDB_RECHECK_INVARIANTS
// disable toplevel invariants checks except for the assertion clause as we're
// doing them anyway in CHECK_INVARIANTS_SUB
// (this redefinition replaces whatever CHECK_INVARIANTS was set to above, so
// only the reachability check remains)
#undef CHECK_INVARIANTS
#define CHECK_INVARIANTS(root, cte, reach) \
do { \
    if (mdb_is_reachable(root, cte) != reach) { \
        panic("mdb_is_reachable(%p,%p) != %d", root, cte, reach); \
    } \
} while(0)
#define CHECK_INVARIANTS_SUB(cte) mdb_check_subtree_invariants(cte)
#else
#define CHECK_INVARIANTS_SUB(cte) ((void)0)
#endif
82
// printf tracing and entry/exit invariant checking.
//
// MDB_TRACE_ENTER(valid_cte, args_fmt, ...): run CHECK_INVARIANTS_SUB on
//   valid_cte; with MDB_TRACE defined, first print "enter <function>(<args>)".
// MDB_TRACE_LEAVE_SUB(valid_cte): run CHECK_INVARIANTS_SUB on valid_cte and
//   return from the enclosing void function; with MDB_TRACE defined, print
//   "leave <function>" before returning.
// MDB_TRACE_LEAVE_SUB_RET(ret_fmt, ret, valid_cte): same, but returns `ret`
//   from the enclosing function; with MDB_TRACE defined, `ret` is also printed
//   using ret_fmt.
//
// Note that both LEAVE macros contain a `return` statement, so no code placed
// after them in the same block is ever executed.
#ifdef MDB_TRACE
#define MDB_TRACE_ENTER(valid_cte, args_fmt, ...) do { \
    printf("enter %s(" args_fmt ")\n", __func__, __VA_ARGS__); \
    CHECK_INVARIANTS_SUB(valid_cte); \
} while (0)
#define MDB_TRACE_LEAVE_SUB(valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    printf("leave %s\n", __func__); \
    return; \
} while (0)
#define MDB_TRACE_LEAVE_SUB_RET(ret_fmt, ret, valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    printf("leave %s->" ret_fmt "\n", __func__, (ret)); \
    return (ret); \
} while (0)
#else
#define MDB_TRACE_ENTER(valid_cte, args_fmt, ...) CHECK_INVARIANTS_SUB(valid_cte)
#define MDB_TRACE_LEAVE_SUB(valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    return; \
} while (0)
#define MDB_TRACE_LEAVE_SUB_RET(ret_fmt, ret, valid_cte) do { \
    CHECK_INVARIANTS_SUB(valid_cte); \
    return (ret); \
} while (0)
#endif
110
// Root of the MDB tree; NULL while the tree is empty.
// Kept global (non-static) for test_ops_with_root.c.
struct cte *mdb_root = NULL;
#if IN_KERNEL
// KCB passed to mdb_init(); its mdb_root field is updated by set_root().
struct kcb *my_kcb = NULL;
#endif

/*
 * Make new_root the root of the tree. In kernel builds the new root is also
 * written to my_kcb->mdb_root.
 */
static void
set_root(struct cte *new_root)
{
    mdb_root = new_root;
#if IN_KERNEL
    my_kcb->mdb_root = (lvaddr_t) new_root;
#endif
}
123
124/*
125 * (re)initialization
126 */
127errval_t
128mdb_init(struct kcb *k)
129{
130    assert (k != NULL);
131#if IN_KERNEL
132#if 0
133    //XXX: write two versions of this; so we can have full sanity checks for
134    //all scenarios -SG
135    if (my_kcb) {
136        printf("MDB has non-null kcb.\n");
137        return CAPS_ERR_MDB_ALREADY_INITIALIZED;
138    }
139#endif
140    my_kcb = k;
141    if (!my_kcb->is_valid) {
142        // empty kcb, do nothing
143        return SYS_ERR_OK;
144    }
145#endif
146    // set root
147    mdb_root = (struct cte *)k->mdb_root;
148
149#if 0
150    // always check invariants here
151    int i = mdb_check_invariants();
152    if (i) {
153        printf("mdb invariant %s violated\n", mdb_invariant_to_str(i));
154        mdb_dump_all_the_things();
155        mdb_root = NULL;
156        return CAPS_ERR_MDB_INVARIANT_VIOLATION;
157    }
158#endif
159    return SYS_ERR_OK;
160}
161
162
163/*
164 * Debug printing.
165 */
166
167void
168mdb_dump_all_the_things(void)
169{
170    mdb_dump(mdb_root, 0);
171}
172
173STATIC_ASSERT(68 == ObjType_Num, "Knowledge of all cap types");
174static void print_cte(struct cte *cte, char *indent_buff)
175{
176    struct mdbnode *node = N(cte);
177    char extra[255] = { 0 };
178    struct capability *cap = C(cte);
179    switch (cap->type) {
180        case ObjType_L1CNode:
181            snprintf(extra, 255,
182                    "[allocated_bytes=%"PRIxGENSIZE",rightsmask=%"PRIu8"]",
183                    cap->u.l1cnode.allocated_bytes, cap->u.l1cnode.rightsmask);
184            break;
185        case ObjType_L2CNode:
186            snprintf(extra, 255, "[rightsmask=%"PRIu8"]", cap->u.l2cnode.rightsmask);
187            break;
188        case ObjType_EndPointLMP:
189            snprintf(extra, 255,
190                    "[listener=%p,epoffset=0x%08"PRIxLVADDR",epbuflen=%"PRIu32", if=%"PRIu32"]",
191                     cap->u.endpointlmp.listener,cap->u.endpointlmp.epoffset,
192                     cap->u.endpointlmp.epbuflen, cap->u.endpointlmp.iftype);
193            break;
194        case ObjType_EndPointUMP:
195            snprintf(extra, 255,
196                     "[base=%" PRIxLPADDR ",if=%"PRIu32"]",
197                     cap->u.endpointump.base, cap->u.endpointump.iftype);
198            break;
199        case ObjType_Dispatcher:
200            snprintf(extra, 255, "[dcb=%p]", cap->u.dispatcher.dcb);
201            break;
202        case ObjType_IO:
203            snprintf(extra, 255, "[start=0x%04"PRIx16",end=0x%04"PRIx16"]",
204                    cap->u.io.start, cap->u.io.end);
205            break;
206        default:
207            break;
208    }
209    printf("%s%p{left=%p,right=%p,end=0x%08"PRIxGENPADDR",end_root=%"PRIu8","
210            "level=%"PRIu8",address=0x%08"PRIxGENPADDR",size=0x%08"PRIx64","
211            "type=%"PRIu8",remote_rels=%d%d%d,extra=%s}\n",
212            indent_buff,
213            cte, node->left, node->right, node->end, node->end_root,
214            node->level, get_address(C(cte)), get_size(C(cte)),
215            (uint8_t)C(cte)->type, node->remote_copies,
216            node->remote_ancs, node->remote_descs,extra);
217    return;
218}
219
220void
221mdb_dump(struct cte *cte, int indent)
222{
223    // Print a tree with root on the left, the smallest element at the top and the
224    // largest at the bottom.
225
226    /* make an indent buffer */
227    char indent_buff[indent+2];
228    for (int i=0; i < indent+1; i++) {
229        indent_buff[i]='\t';
230    }
231    indent_buff[indent+1] = '\0';
232
233    if (!cte) {
234        printf("NULL{}\n");
235        return;
236    }
237
238    struct mdbnode *node = N(cte);
239    assert(node);
240
241    if (node->left) {
242        if (node->left == cte) {
243            printf("%sSELF!!!!\n", indent_buff);
244        }
245        else {
246            mdb_dump(node->left, indent+1);
247        }
248    }
249
250    print_cte(cte, indent_buff);
251
252    if (node->right) {
253        if (node->right == cte) {
254            printf("%sSELF!!!!\n", indent_buff);
255        }
256        else {
257            mdb_dump(node->right, indent+1);
258        }
259    }
260}
261
262
263/*
264 * Invariant checking.
265 */
266
267static int
268mdb_check_subtree_invariants(struct cte *cte)
269{
270    if (!cte) {
271        return MDB_INVARIANT_OK;
272    }
273    if (C(cte)->type == 0) {
274        mdb_dump_all_the_things();
275    }
276    assert(C(cte)->type != 0);
277
278    int err;
279    struct mdbnode *node = N(cte);
280
281    if (node->level > 0 && !(node->left && node->right)) {
282        MDB_RET_INVARIANT(cte, MDB_INVARIANT_BOTHCHILDREN);
283    }
284    if (node->left && !(N(node->left)->level < node->level)) {
285        MDB_RET_INVARIANT(cte, MDB_INVARIANT_LEFT_LEVEL_LESS);
286    }
287    if (node->right && !(N(node->right)->level <= node->level)) {
288        MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHT_LEVEL_LEQ);
289    }
290    if (node->right && N(node->right)->right &&
291        !(N(N(node->right)->right)->level < node->level))
292    {
293        MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHTRIGHT_LEVEL_LESS);
294    }
295    if (node->right && N(node->right)->left &&
296        !(N(N(node->right)->left)->level < node->level))
297    {
298        MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHTLEFT_LEVEL_LESS);
299    }
300
301    // build expected end root for current node
302    mdb_root_t expected_end_root = get_type_root(C(cte)->type);
303    if (node->left) {
304        expected_end_root = MAX(expected_end_root, N(node->left)->end_root);
305    }
306    if (node->right) {
307        expected_end_root = MAX(expected_end_root, N(node->right)->end_root);
308    }
309    if (node->end_root != expected_end_root) {
310        MDB_RET_INVARIANT(cte, MDB_INVARIANT_END_IS_MAX);
311    }
312
313    // build expected end for current node. this is complex because the root
314    // acts as an address prefix, so only ends where the corresponding root is
315    // the maximum may be considered.
316    genpaddr_t expected_end = 0;
317    if (get_type_root(C(cte)->type) == node->end_root) {
318        // only consider current cte's end if its root is node->end_root
319        expected_end = get_address(C(cte))+get_size(C(cte));
320    }
321    if (node->left && N(node->left)->end_root == node->end_root) {
322        // only consider left child end if its end_root is node->end_root
323        expected_end = MAX(expected_end, N(node->left)->end);
324    }
325    if (node->right && N(node->right)->end_root == node->end_root) {
326        // only consider right child end if its end_root is node->end_root
327        expected_end = MAX(expected_end, N(node->right)->end);
328    }
329    if (node->end != expected_end) {
330        MDB_RET_INVARIANT(cte, MDB_INVARIANT_END_IS_MAX);
331    }
332
333    if (node->left) {
334        assert(node->left != cte);
335        if (compare_caps(C(node->left), C(cte), true) >= 0) {
336            MDB_RET_INVARIANT(cte, MDB_INVARIANT_LEFT_SMALLER);
337        }
338        err = mdb_check_subtree_invariants(node->left);
339        if (err) {
340            return err;
341        }
342    }
343
344    if (node->right) {
345        assert(node->right != cte);
346        if (compare_caps(C(node->right), C(cte), true) <= 0) {
347            MDB_RET_INVARIANT(cte, MDB_INVARIANT_RIGHT_GREATER);
348        }
349        err = mdb_check_subtree_invariants(node->right);
350        if (err) {
351            return err;
352        }
353    }
354
355    return MDB_INVARIANT_OK;
356}
357
358int
359mdb_check_invariants(void)
360{
361    int res = mdb_check_subtree_invariants(mdb_root);
362    if (res != 0) {
363        printf("mdb_check_invariants() -> %d\n", res);
364    }
365    return res;
366}
367
368static bool
369mdb_is_reachable(struct cte *root, struct cte *cte)
370{
371    if (!root) {
372        return false;
373    }
374    if (root == cte) {
375        return true;
376    }
377    if (N(root)->left) {
378        if (mdb_is_reachable(N(root)->left, cte)) {
379            return true;
380        }
381    }
382    if (N(root)->right) {
383        if (mdb_is_reachable(N(root)->right, cte)) {
384            return true;
385        }
386    }
387    return false;
388}
389
390/*
391 * General internal helpers.
392 */
393
394static void
395mdb_update_end(struct cte *cte)
396{
397    if (!cte) {
398        return;
399    }
400    struct mdbnode *node = N(cte);
401
402    // build end root for current node
403    mdb_root_t end_root = get_type_root(C(cte)->type);
404    if (node->left) {
405        end_root = MAX(end_root, N(node->left)->end_root);
406    }
407    if (node->right) {
408        end_root = MAX(end_root, N(node->right)->end_root);
409    }
410    node->end_root = end_root;
411
412    // build end address for current node. this is complex because the root
413    // acts as an address prefix, so only ends where the corresponding root is
414    // the maximum may be considered.
415    genpaddr_t end = 0;
416    if (get_type_root(C(cte)->type) == node->end_root) {
417        // only consider current cte's end if its root is node->end_root
418        end = get_address(C(cte))+get_size(C(cte));
419    }
420    if (node->left && N(node->left)->end_root == node->end_root) {
421        // only consider left child end if its end_root is node->end_root
422        end = MAX(end, N(node->left)->end);
423    }
424    if (node->right && N(node->right)->end_root == node->end_root) {
425        // only consider right child end if its end_root is node->end_root
426        end = MAX(end, N(node->right)->end);
427    }
428    node->end = end;
429}
430
431static struct cte*
432mdb_skew(struct cte *node)
433{
434    /* transform invalid state
435     *
436     *               |
437     *      |L|<---|T|
438     *     /   \      \
439     *   |A|   |B|    |R|
440     *
441     * to valid equivalent state
442     *
443     *       |
444     *      |L|--->|T|
445     *     /      /   \
446     *   |A|    |B|   |R|
447     *
448     */
449    if (!node || !N(node)->left) {
450        return node;
451    }
452    else if (N(node)->level == N(N(node)->left)->level) {
453        struct cte *left = N(node)->left;
454        N(node)->left = N(left)->right;
455        N(left)->right = node;
456        mdb_update_end(node);
457        mdb_update_end(left);
458
459        // need to update mdb_root
460        if (mdb_root == node) {
461            set_root(left);
462        }
463        return left;
464    }
465    else {
466        return node;
467    }
468}
469
470static struct cte*
471mdb_split(struct cte *node)
472{
473    /* transform invalid state
474     *
475     *       |
476     *      |T|--->|R|-->|X|
477     *     /      /
478     *   |A|    |B|
479     *
480     * to valid equivalent state
481     *
482     *             |
483     *            |R|
484     *           /   \
485     *         |T|   |X|
486     *        /   \
487     *      |A|   |B|
488     *
489     */
490    if (!node || !N(node)->right || !N(N(node)->right)->right) {
491        return node;
492    }
493    else if (N(node)->level == N(N(N(node)->right)->right)->level) {
494        struct cte *right = N(node)->right;
495        N(node)->right = N(right)->left;
496        N(right)->left = node;
497        N(right)->level += 1;
498        mdb_update_end(node);
499        mdb_update_end(right);
500
501        // need to update mdb_root
502        if (mdb_root == node) {
503            set_root(right);
504        }
505        return right;
506    }
507    else {
508        return node;
509    }
510}
511
512static void
513mdb_decrease_level(struct cte *node)
514{
515    assert(node);
516
517    mdb_level_t expected;
518    if (!N(node)->left || !N(node)->right) {
519        expected = 0;
520    }
521    else {
522        expected = MIN(N(N(node)->left)->level, N(N(node)->right)->level) + 1;
523    }
524
525    if (expected < N(node)->level) {
526        N(node)->level = expected;
527        if (N(node)->right && expected < N(N(node)->right)->level) {
528            N(N(node)->right)->level = expected;
529        }
530    }
531}
532
533static struct cte*
534mdb_rebalance(struct cte *node)
535{
536    assert(node);
537    mdb_update_end(node);
538    mdb_decrease_level(node);
539    node = mdb_skew(node);
540    N(node)->right = mdb_skew(N(node)->right);
541    if (N(node)->right) {
542        N(N(node)->right)->right = mdb_skew(N(N(node)->right)->right);
543    }
544    node = mdb_split(node);
545    N(node)->right = mdb_split(N(node)->right);
546    return node;
547}
548
549#ifndef NDEBUG
550static bool
551mdb_is_child(struct cte *child, struct cte *parent)
552{
553    if (!parent) {
554        return mdb_root == child;
555    }
556    else {
557        return N(parent)->left == child || N(parent)->right == child;
558    }
559}
560#else
561#define mdb_is_child(a, b) 0
562#endif
563
564static bool
565mdb_is_inside(genpaddr_t outer_begin, genpaddr_t outer_end,
566              genpaddr_t inner_begin, genpaddr_t inner_end)
567{
568    assert(outer_begin <= outer_end);
569    assert(inner_begin <= inner_end);
570    return
571        (inner_begin >= outer_begin && inner_end < outer_end) ||
572        (inner_begin > outer_begin && inner_end <= outer_end);
573}
574
575/*
576 * Operations and operation-specific helpers.
577 */
578
579static errval_t
580mdb_sub_insert(struct cte *new_node, struct cte **current)
581{
582    errval_t err;
583    assert(new_node);
584    assert(current);
585    MDB_TRACE_ENTER(*current, "%p, %p (*%p)", new_node, *current, current);
586
587    struct cte *current_ = *current;
588
589    if (!current_) {
590        // we've reached an empty leaf, insert here
591        *current = new_node;
592        mdb_update_end(new_node);
593        return SYS_ERR_OK;
594    }
595
596    int compare = compare_caps(C(new_node), C(current_), true);
597    if (compare < 0) {
598        // new_node < current
599        err = mdb_sub_insert(new_node, &N(current_)->left);
600        if (err_is_fail(err)) {
601            return err;
602        }
603    }
604    else if (compare > 0) {
605        // new_node > current
606        err = mdb_sub_insert(new_node, &N(current_)->right);
607        if (err_is_fail(err)) {
608            return err;
609        }
610    }
611    else {
612        return CAPS_ERR_MDB_DUPLICATE_ENTRY;
613    }
614
615    mdb_update_end(current_);
616    current_ = mdb_skew(current_);
617    current_ = mdb_split(current_);
618    *current = current_;
619
620    err = SYS_ERR_OK;
621    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
622}
623
624uint64_t mdb_insert_count;
625errval_t
626mdb_insert(struct cte *new_node)
627{
628    MDB_TRACE_ENTER(mdb_root, "%p", new_node);
629#ifdef IN_KERNEL
630#ifdef MDB_TRACE_NO_RECURSIVE
631    char prefix[50];
632    snprintf(prefix, 50, "mdb_insert.%d: ", my_core_id);
633    print_cte(new_node, prefix);
634#endif
635#endif
636    mdb_insert_count ++;
637    errval_t ret = mdb_sub_insert(new_node, &mdb_root);
638    CHECK_INVARIANTS(mdb_root, new_node, true);
639    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, ret, mdb_root);
640}
641
642static void
643mdb_exchange_child(struct cte *first, struct cte *first_parent,
644                   struct cte *second)
645{
646    assert(mdb_is_child(first, first_parent));
647
648    if (!first_parent) {
649        set_root(second);
650    }
651    else if (N(first_parent)->left == first) {
652        N(first_parent)->left = second;
653    }
654    else if (N(first_parent)->right == first) {
655        N(first_parent)->right = second;
656    }
657    else {
658        assert(!"first is not child of first_parent");
659    }
660}
661
662static void
663mdb_exchange_nodes(struct cte *first, struct cte *first_parent,
664                   struct cte *second, struct cte *second_parent)
665{
666    struct cte *tmp_node;
667    mdb_level_t tmp_level;
668
669    mdb_exchange_child(first, first_parent, second);
670    mdb_exchange_child(second, second_parent, first);
671
672    tmp_node = N(first)->left;
673    N(first)->left = N(second)->left;
674    N(second)->left = tmp_node;
675
676    tmp_node = N(first)->right;
677    N(first)->right = N(second)->right;
678    N(second)->right = tmp_node;
679
680    tmp_level = N(first)->level;
681    N(first)->level = N(second)->level;
682    N(second)->level = tmp_level;
683
684    mdb_update_end(first);
685    mdb_update_end(second);
686
687    assert(mdb_is_reachable(mdb_root, first));
688    assert(mdb_is_reachable(mdb_root, second));
689}
690
/*
 * Helper for mdb_subtree_remove(): removes `target`, which has at least one
 * child, by exchanging it with a node found further down and then unlinking
 * target from that lower position.
 *
 * target         node to remove
 * target_parent  parent of target (NULL if target is the root)
 * current        slot holding the node currently being visited; updated to
 *                the (possibly rebalanced or emptied) subtree on return
 * parent         parent of *current
 * dir            search direction, must be non-zero: > 0 keeps following
 *                right links, < 0 keeps following left links
 * ret_target     must point to NULL on entry; receives the node that took
 *                target's place in the tree
 *
 * On the way back up the recursion every visited node is rebalanced.
 */
static void
mdb_exchange_remove(struct cte *target, struct cte *target_parent,
                    struct cte **current, struct cte *parent,
                    int dir, struct cte **ret_target)
{
    assert(current);
    MDB_TRACE_ENTER(*current, "%p, %p, %p (*%p), %p, %d", target, target_parent, *current, current, parent, dir);
    assert(target);
    assert(*current);
    assert(parent);
    assert(C(target)->type != 0);
    assert(C(*current)->type != 0);
    assert(C(parent)->type != 0);
    assert(ret_target);
    assert(!*ret_target);
    assert(dir != 0);
    assert(compare_caps(C(target), C(*current), true) != 0);
    assert(mdb_is_child(target, target_parent));
    assert(mdb_is_child(*current, parent));
    assert(mdb_is_reachable(mdb_root, target));

    struct cte *current_ = *current;

    if (dir > 0) {
        // going right: on the first step (parent == target) current is the
        // target's left child, afterwards always the parent's right child
        if (parent == target) {
            assert(N(parent)->left == current_);
        }
        else {
            assert(N(parent)->right == current_);
        }

        if (N(current_)->right) {
            mdb_exchange_remove(target, target_parent, &N(current_)->right,
                                current_, dir, ret_target);
        }
    }
    else if (dir < 0) {
        // going left: mirror image of the case above
        if (parent == target) {
            assert(N(parent)->right == current_);
        }
        else {
            assert(N(parent)->left == current_);
        }

        if (N(current_)->left) {
            mdb_exchange_remove(target, target_parent, &N(current_)->left,
                                current_, dir, ret_target);
        }
        else if (N(current_)->right) {
            assert(N(current_)->level == 0);
            // right is non-null, left null -> current is level 0 node with
            // horizontal right link, and is also the successor of the target.
            // in this case, exchange current and current right, then current
            // (at its new position) and the target.
            struct cte *new_current = N(current_)->right;
            mdb_exchange_nodes(current_, parent, N(current_)->right, current_);
            mdb_exchange_nodes(target, target_parent, current_, new_current);
            // "current" is now located where the target was, further up in the
            // tree. "new_current" is the node where current was. "target" is
            // where current->right was, and is a leaf, so can be dropped.
            assert(N(new_current)->right == target);
            N(new_current)->right = NULL;
            *ret_target = current_;
            *current = new_current;
            assert(!mdb_is_reachable(mdb_root, target));
            // this macro returns from the function
            MDB_TRACE_LEAVE_SUB(NULL);
        }
    }

    if (*ret_target) {
        assert(!mdb_is_reachable(mdb_root, target));
        // implies we recursed further down to find a leaf. need to rebalance.
        current_ = mdb_rebalance(current_);
        *current = current_;
        MDB_TRACE_LEAVE_SUB(current_);
    }
    else {
        //printf("found leaf %p\n", current_);
        // found successor/predecessor leaf, exchange with target
        assert(!N(current_)->right && !N(current_)->left);
        mdb_exchange_nodes(target, target_parent, current_, parent);

        // "current" is now where target was, so set as ret_target
        *ret_target = current_;
        // target would be the new current, but we're removing it, so set
        // current to null. This also sets parent's corresponding child to
        // null by recursion.
        *current = NULL;
        MDB_TRACE_LEAVE_SUB(NULL);
    }
}
782
783static errval_t
784mdb_subtree_remove(struct cte *target, struct cte **current, struct cte *parent)
785{
786    assert(current);
787    MDB_TRACE_ENTER(*current, "%p, %p (*%p), %p", target, *current, current, parent);
788
789    errval_t err;
790    struct cte *current_ = *current;
791    if (!current_) {
792        err = CAPS_ERR_MDB_ENTRY_NOTFOUND;
793        MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
794    }
795
796    int compare = compare_caps(C(target), C(current_), true);
797    if (compare > 0) {
798        err = mdb_subtree_remove(target, &N(current_)->right, current_);
799        if (err != SYS_ERR_OK) {
800            MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
801            return err;
802        }
803    }
804    else if (compare < 0) {
805        err = mdb_subtree_remove(target, &N(current_)->left, current_);
806        if (err != SYS_ERR_OK) {
807            MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
808        }
809    }
810    else {
811        assert(current_ == target);
812        if (!N(current_)->left && !N(current_)->right) {
813            // target is leaf, just remove
814            *current = NULL;
815            err = SYS_ERR_OK;
816            MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, NULL);
817        }
818        else if (!N(current_)->left) {
819            // move to right child then go left (dir=-1)
820            // curr, new_right = xchg_rm(elem, parent, current.right, current, -1)
821            struct cte *new_current = NULL;
822            struct cte *new_right = N(current_)->right;
823            mdb_exchange_remove(target, parent, &new_right, current_, -1,
824                                &new_current);
825            assert(new_current);
826            current_ = new_current;
827            N(current_)->right = new_right;
828            assert(!mdb_is_reachable(mdb_root, target));
829        }
830        else {
831            // move to left child then go right (dir=1)
832            // curr, new_left = xchg_rm(elem, parent, current.left, current, 1)
833            struct cte *new_current = NULL;
834            struct cte *new_left = N(current_)->left;
835            mdb_exchange_remove(target, parent, &new_left, current_, 1,
836                                &new_current);
837            assert(new_current);
838            current_ = new_current;
839            N(current_)->left = new_left;
840            assert(!mdb_is_reachable(mdb_root, target));
841        }
842    }
843
844    // rebalance after remove from subtree
845    current_ = mdb_rebalance(current_);
846    *current = current_;
847
848    assert(C(target)->type != 0);
849    assert(!*current || C(*current)->type != 0);
850
851    err = SYS_ERR_OK;
852    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, current_);
853}
854
855uint64_t mdb_remove_count;
856errval_t
857mdb_remove(struct cte *target)
858{
859    MDB_TRACE_ENTER(mdb_root, "%p", target);
860    CHECK_INVARIANTS(mdb_root, target, true);
861#ifdef IN_KERNEL
862#ifdef MDB_TRACE_NO_RECURSIVE
863    char prefix[50];
864    snprintf(prefix, 50, "mdb_remove.%d: ", my_core_id);
865    print_cte(target, prefix);
866#endif
867#endif
868    mdb_remove_count++;
869    errval_t err = mdb_subtree_remove(target, &mdb_root, NULL);
870    CHECK_INVARIANTS(mdb_root, target, false);
871    MDB_TRACE_LEAVE_SUB_RET("%"PRIuPTR, err, mdb_root);
872}
873
874/*
875 * Queries on the ordering.
876 */
877
878static struct cte*
879mdb_sub_find_equal(struct capability *cap, struct cte *current)
880{
881    if (!current) {
882        return NULL;
883    }
884    int compare = compare_caps(cap, C(current), false);
885    if (compare < 0) {
886        // current is gt key, look for smaller node
887        return mdb_sub_find_equal(cap, N(current)->left);
888    }
889    else if (compare > 0) {
890        // current is lt key, attempt to find bigger current
891        return mdb_sub_find_equal(cap, N(current)->right);
892    }
893    else {
894        return current;
895    }
896}
897
898uint64_t mdb_find_equal_count;
899struct cte*
900mdb_find_equal(struct capability *cap)
901{
902    mdb_find_equal_count++;
903    return mdb_sub_find_equal(cap, mdb_root);
904}
905
906static struct cte*
907mdb_sub_find_less(struct capability *cap, struct cte *current, bool equal_ok,
908                  bool tiebreak)
909{
910    if (!current) {
911        return NULL;
912    }
913    int compare = compare_caps(cap, C(current), tiebreak);
914    if (compare < 0) {
915        // current is gt key, look for smaller node
916        return mdb_sub_find_less(cap, N(current)->left, equal_ok, tiebreak);
917    }
918    else if (compare > 0) {
919        // current is lt key, attempt to find bigger current
920        struct cte *res = mdb_sub_find_less(cap, N(current)->right, equal_ok,
921                                            tiebreak);
922        if (res) {
923            return res;
924        }
925        // bigger child exceeded key
926        return current;
927    }
928    else {
929        // found equal element
930        if (equal_ok) {
931            return current;
932        }
933        else {
934            // look for smaller node
935            return mdb_sub_find_less(cap, N(current)->left, equal_ok,
936                                     tiebreak);
937        }
938    }
939}
940
941uint64_t mdb_find_less_count;
942struct cte*
943mdb_find_less(struct capability *cap, bool equal_ok)
944{
945    mdb_find_less_count++;
946    return mdb_sub_find_less(cap, mdb_root, equal_ok, false);
947}
948
949static struct cte*
950mdb_sub_find_greater(struct capability *cap, struct cte *current,
951                     bool equal_ok, bool tiebreak)
952{
953    if (!current) {
954        return NULL;
955    }
956    int compare = compare_caps(cap, C(current), tiebreak);
957    if (compare < 0) {
958        // current is gt key, attempt to find smaller node
959        struct cte *res = mdb_sub_find_greater(cap, N(current)->left, equal_ok,
960                                               tiebreak);
961        if (res) {
962            return res;
963        }
964        // smaller was lte key
965        return current;
966    }
967    else if (compare > 0) {
968        // current is lte key, look for greater node
969        return mdb_sub_find_greater(cap, N(current)->right, equal_ok,
970                                    tiebreak);
971    }
972    else {
973        // found equal element
974        if (equal_ok) {
975            return current;
976        }
977        else {
978            // look for greater node
979            return mdb_sub_find_greater(cap, N(current)->right, equal_ok,
980                                        tiebreak);
981        }
982    }
983}
984
985uint64_t mdb_find_greater_count;
986struct cte*
987mdb_find_greater(struct capability *cap, bool equal_ok)
988{
989    mdb_find_greater_count++;
990    return mdb_sub_find_greater(cap, mdb_root, equal_ok, false);
991}
992
993uint64_t mdb_predecessor_count;
994struct cte*
995mdb_predecessor(struct cte *current)
996{
997    mdb_predecessor_count++;
998    struct mdbnode *node = N(current);
999    if (node->left) {
1000        // if possible, look just at children
1001        current = node->left;
1002        while ((node = N(current))->right) {
1003            current = node->right;
1004        }
1005        return current;
1006    }
1007    // XXX: in lieu of a parent pointer that can be used to traverse upwards,
1008    // we have to perform a search through the tree from the root. This makes
1009    // "predecessor" into a O(log(n)) operation, instead of the expected O(1).
1010    return mdb_sub_find_less(C(current), mdb_root, false, true);
1011}
1012
1013uint64_t mdb_successor_count;
1014struct cte*
1015mdb_successor(struct cte *current)
1016{
1017    mdb_successor_count++;
1018    struct mdbnode *node = N(current);
1019    if (node->right) {
1020        // if possible, look just at children
1021        current = node->right;
1022        while ((node = N(current))->left) {
1023            current = node->left;
1024        }
1025        return current;
1026    }
1027    // XXX: in lieu of a parent pointer that can be used to traverse upwards,
1028    // we have perform a search through the tree from the root. This makes
1029    // "successor" into a O(log(n)) operation, instead of the expected O(1).
1030    return mdb_sub_find_greater(C(current), mdb_root, false, true);
1031}
1032
1033/*
1034 * The range query.
1035 */
1036
1037static struct cte*
1038mdb_choose_surrounding(genpaddr_t address, size_t size, struct cte *first,
1039                       struct cte *second)
1040{
1041    assert(first);
1042    assert(second);
1043    assert(get_type_root(C(first)->type) == get_type_root(C(second)->type));
1044#ifndef NDEBUG
1045    genpaddr_t beg = address, end = address + size;
1046    genpaddr_t fst_beg = get_address(C(first));
1047    genpaddr_t snd_beg = get_address(C(second));
1048    genpaddr_t fst_end = fst_beg + get_size(C(first));
1049    genpaddr_t snd_end = snd_beg + get_size(C(second));
1050    assert(fst_beg <= beg && fst_end >= end);
1051    assert(snd_beg <= beg && snd_end >= end);
1052#endif
1053
1054    if (compare_caps(C(first), C(second), true) >= 0) {
1055        return first;
1056    }
1057    else {
1058        return second;
1059    }
1060}
1061
1062static struct cte*
1063mdb_choose_inner(genpaddr_t address, size_t size, struct cte *first,
1064                 struct cte *second)
1065{
1066    assert(first);
1067    assert(second);
1068    assert(get_type_root(C(first)->type) == get_type_root(C(second)->type));
1069#ifndef NDEBUG
1070    genpaddr_t end = address + size;
1071    genpaddr_t fst_beg = get_address(C(first));
1072    genpaddr_t snd_beg = get_address(C(second));
1073    genpaddr_t fst_end = fst_beg + get_size(C(first));
1074    genpaddr_t snd_end = snd_beg + get_size(C(second));
1075    assert(mdb_is_inside(address, end, fst_beg, fst_end));
1076    assert(mdb_is_inside(address, end, snd_beg, snd_end));
1077#endif
1078
1079    if (compare_caps(C(first), C(second), true) <= 0) {
1080        return first;
1081    }
1082    else {
1083        return second;
1084    }
1085}
1086
1087static struct cte*
1088mdb_choose_partial(genpaddr_t address, size_t size, struct cte *first,
1089                   struct cte *second)
1090{
1091    assert(first);
1092    assert(second);
1093    assert(get_type_root(C(first)->type) == get_type_root(C(second)->type));
1094    genpaddr_t beg = address;
1095    genpaddr_t fst_beg = get_address(C(first));
1096    genpaddr_t snd_beg = get_address(C(second));
1097#ifndef NDEBUG
1098    genpaddr_t end = address + size;
1099    genpaddr_t fst_end = fst_beg + get_size(C(first));
1100    genpaddr_t snd_end = snd_beg + get_size(C(second));
1101    assert(fst_beg < end);
1102    assert(snd_beg < end);
1103    assert(fst_end > beg);
1104    assert(snd_end > beg);
1105    assert(fst_beg != beg);
1106    assert(snd_beg != beg);
1107    assert(fst_end != end);
1108    assert(snd_end != end);
1109    assert((fst_beg < beg) == (fst_end < end));
1110    assert((snd_beg < beg) == (snd_end < end));
1111#endif
1112
1113    if (fst_beg < beg && snd_beg > beg) {
1114        return first;
1115    }
1116    else if (snd_beg < beg && fst_beg > beg) {
1117        return second;
1118    }
1119    else {
1120        if (compare_caps(C(first), C(second), true) >= 0) {
1121            return first;
1122        }
1123        else {
1124            return second;
1125        }
1126    }
1127}
1128
// Forward declaration: mdb_sub_find_range_merge() calls mdb_sub_find_range(),
// which in turn calls mdb_sub_find_range_merge() for its child subtrees.
static int
mdb_sub_find_range(mdb_root_t root, genpaddr_t address, size_t size,
                   int max_precision, struct cte *current,
                   /*out*/ struct cte **ret_node);
1133
1134static void
1135mdb_sub_find_range_merge(mdb_root_t root, genpaddr_t address, size_t size,
1136                         int max_precision, struct cte *sub,
1137                         /*inout*/ int *ret, /*inout*/ struct cte **result)
1138{
1139    assert(sub);
1140    assert(ret);
1141    assert(result);
1142    assert(max_precision >= 0);
1143    assert(*ret <= max_precision);
1144
1145    struct cte *sub_result = NULL;
1146    int sub_ret = mdb_sub_find_range(root, address, size, max_precision, sub,
1147                                     &sub_result);
1148    if (sub_ret > max_precision) {
1149        *result = sub_result;
1150        *ret = sub_ret;
1151    }
1152    else if (sub_ret > *ret) {
1153        *result = sub_result;
1154        *ret = sub_ret;
1155    }
1156    else if (sub_ret == *ret) {
1157        switch (sub_ret) {
1158        case MDB_RANGE_NOT_FOUND:
1159            break;
1160        case MDB_RANGE_FOUND_SURROUNDING:
1161            *result = mdb_choose_surrounding(address, size, *result, sub_result);
1162            break;
1163        case MDB_RANGE_FOUND_INNER:
1164            *result = mdb_choose_inner(address, size, *result, sub_result);
1165            break;
1166        case MDB_RANGE_FOUND_PARTIAL:
1167            *result = mdb_choose_partial(address, size, *result, sub_result);
1168            break;
1169        default:
1170            assert(!"Unhandled enum value for mdb_find_range result");
1171            break;
1172        }
1173    }
1174    // else ret > sub_ret, keep ret & result as is
1175}
1176
/*
 * Recursive worker of the range query: search the subtree rooted at
 * `current` for a capability with type root `root` that relates to the range
 * [address, address + size).
 *
 * root:          type root the matching capability must have.
 * address, size: the queried range.
 * max_precision: largest result code the caller wants searched exhaustively;
 *                as soon as a match with a larger code is found the search
 *                stops and that match is returned.
 * current:       root of the subtree to search; may be NULL.
 * ret_node:      out parameter, always written; receives the matching node
 *                or NULL if nothing was found.
 *
 * Returns one of the MDB_RANGE_* codes describing how *ret_node relates to
 * the queried range.
 */
static int
mdb_sub_find_range(mdb_root_t root, genpaddr_t address, size_t size,
                   int max_precision, struct cte *current,
                   /*out*/ struct cte **ret_node)
{
    assert(max_precision >= 0);
    assert(ret_node);

    // empty subtree
    if (!current) {
        *ret_node = NULL;
        return MDB_RANGE_NOT_FOUND;
    }

    // Prune using the node's end_root/end fields.
    // NOTE(review): presumably these hold the largest (type root, end
    // address) reached by any node in this subtree -- confirm against the
    // code that maintains them.
    if (N(current)->end_root < root) {
        *ret_node = NULL;
        return MDB_RANGE_NOT_FOUND;
    }
    if (N(current)->end_root == root && N(current)->end <= address) {
        *ret_node = NULL;
        return MDB_RANGE_NOT_FOUND;
    }

    mdb_root_t current_root = get_type_root(C(current)->type);

    // best match found so far in this subtree
    struct cte *result = NULL;
    int ret = MDB_RANGE_NOT_FOUND;

    genpaddr_t current_address = get_address(C(current));
    genpaddr_t current_end = current_address + get_size(C(current));
    genpaddr_t search_end = address + size;

    // Classify the current node itself, but only if it has the right root.
    if (current_root == root) {

        // partial overlap: current starts strictly inside the queried range
        // and extends past its end
        if (ret < MDB_RANGE_FOUND_PARTIAL &&
            current_address > address &&
            current_address < search_end &&
            current_end > search_end)
        {
            result = current;
            ret = MDB_RANGE_FOUND_PARTIAL;
        }
        // partial overlap: current starts before the queried range and ends
        // strictly inside it
        if (ret < MDB_RANGE_FOUND_PARTIAL &&
            current_end > address &&
            current_end < search_end &&
            current_address < address)
        {
            result = current;
            ret = MDB_RANGE_FOUND_PARTIAL;
        }
        // inner match as decided by mdb_is_inside(); skipped if a partial
        // overlap was already recorded above
        if (ret < MDB_RANGE_FOUND_INNER &&
            mdb_is_inside(address, search_end, current_address, current_end))
        {
            result = current;
            ret = MDB_RANGE_FOUND_INNER;
        }
        // surrounding match: current covers the whole queried range; only
        // considered if nothing else matched
        if (ret < MDB_RANGE_FOUND_SURROUNDING &&
            current_address <= address &&
            // exclude 0-length match with curaddr==addr
            current_address < search_end &&
            current_end >= search_end &&
            // exclude 0-length match with currend==addr
            current_end > address)
        {
            result = current;
            ret = MDB_RANGE_FOUND_SURROUNDING;
        }
        // result code beyond what the caller asked for: stop searching
        if (ret > max_precision) {
            *ret_node = result;
            return ret;
        }
    }

    // The left subtree is always searched (subject to the pruning at the top
    // of the recursive call).
    if (N(current)->left) {
        mdb_sub_find_range_merge(root, address, size, max_precision,
                                 N(current)->left, /*inout*/&ret,
                                 /*inout*/&result);
        if (ret > max_precision) {
            *ret_node = result;
            return ret;
        }
    }

    // The right subtree is only searched if the wanted root is not smaller
    // than the current node's root and the queried range reaches past the
    // current node's start address (or is a zero-size query located exactly
    // at that address).
    if (N(current)->right && root >= current_root &&
        (search_end > current_address || (search_end == current_address && size == 0))) {
        mdb_sub_find_range_merge(root, address, size, max_precision,
                                 N(current)->right, /*inout*/&ret,
                                 /*inout*/&result);
        if (ret > max_precision) {
            *ret_node = result;
            return ret;
        }
    }

    // whole subtree searched: report the best match seen
    *ret_node = result;
    return ret;

}
1274
1275uint64_t mdb_find_range_count;
1276errval_t
1277mdb_find_range(mdb_root_t root, genpaddr_t address, gensize_t size,
1278               int max_result, /*out*/ struct cte **ret_node,
1279               /*out*/ int *result)
1280{
1281    mdb_find_range_count++;
1282    if (max_result < MDB_RANGE_NOT_FOUND ||
1283        max_result > MDB_RANGE_FOUND_PARTIAL)
1284    {
1285        return CAPS_ERR_INVALID_ARGS;
1286    }
1287    if (max_result > MDB_RANGE_NOT_FOUND && !ret_node) {
1288        return CAPS_ERR_INVALID_ARGS;
1289    }
1290    if (!result) {
1291        return CAPS_ERR_INVALID_ARGS;
1292    }
1293
1294    struct cte *alt_ret_node;
1295    if (!ret_node) {
1296        ret_node = &alt_ret_node;
1297    }
1298
1299    *result = mdb_sub_find_range(root, address, size, max_result, mdb_root, ret_node);
1300    return SYS_ERR_OK;
1301}
1302
1303uint64_t mdb_find_cap_for_address_count;
1304errval_t
1305mdb_find_cap_for_address(genpaddr_t address, struct cte **ret_node)
1306{
1307    mdb_find_cap_for_address_count++;
1308    int result;
1309    errval_t err;
1310    // query for size 1 to get the smallest cap that includes the byte at the
1311    // given address
1312    err = mdb_find_range(get_type_root(ObjType_RAM), address,
1313                         1, MDB_RANGE_FOUND_SURROUNDING, ret_node, &result);
1314    if (err_is_fail(err)) {
1315        return err;
1316    }
1317    if (result != MDB_RANGE_FOUND_SURROUNDING) {
1318        return SYS_ERR_CAP_NOT_FOUND;
1319    }
1320    return SYS_ERR_OK;
1321}
1322
1323bool mdb_reachable(struct cte *cte)
1324{
1325    return mdb_is_reachable(mdb_root, cte);
1326}
1327
1328errval_t
1329mdb_traverse(enum mdb_tree_traversal_order order, mdb_tree_traversal_fn cb, void *data)
1330{
1331    return mdb_traverse_subtree(mdb_root, order, cb, data);
1332}
1333
1334errval_t
1335mdb_traverse_subtree(struct cte *cte, enum mdb_tree_traversal_order order,
1336        mdb_tree_traversal_fn cb, void *data)
1337{
1338    struct mdbnode *node = N(cte);
1339    assert(node);
1340
1341    struct cte *first, *second;
1342    if (order == MDB_TRAVERSAL_ORDER_ASCENDING) {
1343        first = node->left;
1344        second = node->right;
1345    } else {
1346        first = node->right;
1347        second = node->left;
1348    }
1349
1350    errval_t err;
1351
1352    if (first) {
1353        err = mdb_traverse_subtree(first, order, cb, data);
1354        if (err_is_fail(err)) {
1355            return err;
1356        }
1357    }
1358
1359    err = cb(cte, data);
1360
1361    if (err_is_fail(err)) {
1362        return err;
1363    }
1364
1365    if (second) {
1366        err = mdb_traverse_subtree(second, order, cb, data);
1367        if (err_is_fail(err)) {
1368            return err;
1369        }
1370    }
1371    return SYS_ERR_OK;
1372}
1373
1374static errval_t count_nodes(struct cte *cte, void *data)
1375{
1376    size_t *counter = data;
1377    *counter = *counter + 1;
1378    return SYS_ERR_OK;
1379}
1380
1381errval_t
1382mdb_size(size_t *count)
1383{
1384    return mdb_traverse(MDB_TRAVERSAL_ORDER_ASCENDING, count_nodes, count);
1385}
1386