summaryrefslogblamecommitdiffstats
path: root/klib/rbtree.c
blob: 29203841ccf5bf957dc6fa1948c2ae2c2330c481 (plain) (tree)








































































































































































































































































































































































                                                                                
#include "rbtree.h"
#include "assert.h"
#include "string.h"

/* Node colors stored in rbnode_t.color. */
#define BLACK 0
#define RED 1

/* Indices into rbnode_t.children. `left`/`right` are object-like macros, so
 * every later identifier spelled `left` or `right` in this file is rewritten
 * to the corresponding children[] slot. */
#define LEFT 0
#define left children[LEFT]

#define RIGHT 1
#define right children[RIGHT]

/* Side (LEFT or RIGHT) on which node c hangs off its parent.
 * c must have a non-NULL parent. The argument is parenthesized so that
 * expressions such as CHILD_DIR(&nodes[i]) or CHILD_DIR(p + 1) expand
 * correctly; it is evaluated twice, so pass a side-effect-free expression. */
#define CHILD_DIR(c) ((c)->parent->children[LEFT] == (c) ? LEFT : RIGHT)

/**Rotate the subtree rooted at @p root in direction @p dir, preserving the
 * in-order sequence of the elements.
 * The opposite-direction child of root (root->children[1-dir]) is the pivot.
 * The pivot's dir-side child becomes root's (1-dir)-side child, root becomes
 * the pivot's dir-side child, and the pivot takes root's former place under
 * root's parent (or becomes tree->root when root had no parent).
 * @param tree to do rotation on; tree->root is updated when root was the root.
 * @param root of rotation; must be non-NULL and have a non-NULL pivot.
 * @param dir to rotate by.
 * @return new root of the subtree (the former pivot).*/
rbnode_t*
s_node_rotate(rbtree_t *tree, rbnode_t *root, bool dir)
{
    assert(root != NULL);
    rbnode_t *gparent = root->parent;
    rbnode_t *pivot = root->children[1-dir];
    rbnode_t *child = NULL;

    assert(pivot != NULL);

    /* Record which slot of gparent holds root BEFORE root->parent is
     * reassigned below. CHILD_DIR(root) reads root->parent, so evaluating it
     * after the rewire would inspect the pivot instead of gparent and pick
     * the wrong slot whenever root was gparent's (1-dir) child. */
    int slot = 0;
    if(gparent != NULL) slot = CHILD_DIR(root);

    child = pivot->children[dir];

    /* Opposite child of root is the dir-most child of the pivot.*/
    root->children[1-dir] = child;
    if(child != NULL) child->parent = root;

    /* Root hangs off the pivot on the dir side.*/
    pivot->children[dir] = root;
    root->parent = pivot;

    /* Pivot takes root's old place. */
    pivot->parent = gparent;
    if(gparent != NULL)
        gparent->children[slot] = pivot;
    else
        tree->root = pivot;

    return pivot;
}

/**Look up the node whose key equals @p key.
 * The descent picks children[node->key > key] at every step, so keys smaller
 * than a node live under children[1] (right) and larger keys under
 * children[0] (left).
 * @param tree to search through.
 * @param key to search for.
 * @return the matching node, or NULL when no node has that key.*/
rbnode_t*
s_find(rbtree_t *tree, intmax_t key)
{
    rbnode_t *cur;
    for(cur = tree->root; cur != NULL; cur = cur->children[cur->key > key]) {
        if(cur->key == key)
            break;
    }
    return cur;
}

/**Find the node holding @p key, or the node a new @p key would hang from.
 * Uses the same descent as s_find, but stops at the last non-NULL node
 * instead of walking off the tree.
 * @param tree to search through.
 * @param key to search for.
 * @return the node with an equal key if one exists; otherwise the node whose
 *         children[node->key > key] slot is empty; NULL only for an empty tree.*/
rbnode_t*
s_closest(rbtree_t *tree, intmax_t key)
{
    rbnode_t *cur = tree->root;
    if(cur == NULL)
        return NULL;
    for(;;) {
        if(cur->key == key)
            return cur;
        rbnode_t *next = cur->children[cur->key > key];
        if(next == NULL)
            return cur;
        cur = next;
    }
}

/**Return the left-most node of the subtree rooted at @p node.
 * Because insertion sends larger keys to children[0], the left-most node is
 * the one with the largest key in the subtree.
 * @param node root of the subtree to search through; must not be NULL.
 * @return left-most node.*/
rbnode_t*
s_node_min(rbnode_t *node)
{
    rbnode_t *cur = node;
    for(rbnode_t *next = cur->left; next != NULL; next = cur->left)
        cur = next;
    return cur;
}

/**Find the node that follows @p node in an in-order (left-to-right) walk.
 * This tree places larger keys under children[0] (see the `node->key > key`
 * descent in s_find), so the in-order successor holds the next SMALLER key.
 * @param node to start from; must not be NULL.
 * @return the successor, or NULL when @p node is last in order.*/
rbnode_t*
s_node_successor(rbnode_t *node)
{
    /* With a right subtree, the successor is that subtree's left-most node. */
    if(node->right != NULL) {
        rbnode_t *cur = node->right;
        while(cur->left != NULL)
            cur = cur->left;
        return cur;
    }
    /* With no right child, climb while we are our parent's right child; the
     * first ancestor reached from its left side is the successor. */
    rbnode_t *child = node;
    rbnode_t *up = node->parent;
    for(; up != NULL && up->right == child; up = up->parent)
        child = up;
    return up;
}

/**Overwrite the key and payload of @p dest with those of @p src.
 * Only key and data change; dest keeps its parent, children and color.
 * @param tree supplies the payload width (tree->objw bytes).
 * @param dest destination node.
 * @param src source node.*/
void
s_node_copy_value(rbtree_t *tree, rbnode_t *dest, rbnode_t *src)
{
    size_t width = tree->objw;
    memcpy(dest->data, src->data, width);
    dest->key = src->key;
}

/**Restore the red-black invariants after a red node has been attached.
 * Walks upward from @p node. While both parent and uncle are red it recolors
 * and continues from the grandparent. Once the uncle is black (or absent) it
 * finishes with one or two rotations. When the walk reaches the root, the root
 * is painted black.
 * @param tree to rebalance.
 * @param node that was inserted (colored RED by s_node_insert).
 * @param parent of the node; recomputed from node->parent, kept for interface
 *        compatibility.
 * @param dir that the node was inserted in; recomputed internally, kept for
 *        interface compatibility.*/
void
s_node_insert_fix(
        rbtree_t *tree,
        rbnode_t *node,
        rbnode_t *parent,
        bool dir)
{
    rbnode_t *gparent;
    rbnode_t *uncle;
    while((parent = node->parent) != NULL) {
        if(parent->color == BLACK) return; /* Tree is valid. no fixes needed.*/
        /*Parent is known to be red.*/
        if((gparent = parent->parent) == NULL) {
            /*Parent is red and the root. We can just fix it and return.*/
            parent->color = BLACK;
            return;
        }
        dir = CHILD_DIR(parent);
        uncle = gparent->children[1-dir];
        if(uncle != NULL && uncle->color == RED) {
            /* Parent and uncle are both red: move the grandparent's blackness
             * down to them and continue from the grandparent, which is now red
             * and may itself have a red parent. Without this case the loop
             * would revisit the same node forever. */
            parent->color = BLACK;
            uncle->color = BLACK;
            gparent->color = RED;
            node = gparent;
            continue;
        }
        /* Parent is red, but uncle (opposite child of gparent) is black.
         * Rotate gparent so the parent takes its place. If node is an inner
         * child of gparent, first rotate it up into the parent's position. */
        if(node == parent->children[1-dir]) {
            s_node_rotate(tree, parent, dir);
            /* After the rotation the old node is the old parent's parent. */
            node = parent;
            parent = node->parent;
        }
        /* node is now an outer child; one rotation finishes the fix. */
        s_node_rotate(tree, gparent, 1-dir);
        parent->color = BLACK;
        gparent->color = RED;
        return;
    }
    /* node is the root (reached by recoloring upward); keep the root black. */
    node->color = BLACK;
}

/**Attach @p node as the @p dir child of @p parent and rebalance.
 * The node's links are reset. A non-root node is colored RED and
 * s_node_insert_fix is run. When @p parent is NULL the node becomes
 * tree->root (the previous root pointer is overwritten) and is colored BLACK,
 * matching the black root that rbtree_insert creates.
 * @param tree to insert into.
 * @param node to insert; its previous links are discarded.
 * @param parent node to insert onto, or NULL to make node the root.
 * @param dir direction to insert node; parent->children[dir] is overwritten,
 *        so it is expected to be empty.*/
void
s_node_insert(rbtree_t *tree, rbnode_t *node, rbnode_t *parent, bool dir)
{
    node->left = NULL;
    node->right = NULL;
    node->parent = parent;
    if(parent == NULL) {
        node->color = BLACK;
        tree->root = node;
        return;
    }
    node->color = RED;
    parent->children[dir] = node;
    s_node_insert_fix(tree, node, parent, dir);
}

/**Unlink a black leaf from its parent and rebalance the tree.
 * Removing a black leaf leaves one path a black node short. The loop pushes
 * that deficit up the tree by recoloring, and the labelled cases below resolve
 * it with rotations.
 * The node itself is NOT freed; the caller still owns it.
 * @param tree to rebalance.
 * @param node black leaf with no children and a non-NULL parent.*/
void
s_node_remove_fixup(rbtree_t *tree, rbnode_t *node)
{
    bool dir;
    rbnode_t *parent = node->parent;
    rbnode_t *sibling;
    rbnode_t *cnephew;  /* sibling's child on the same side as node */
    rbnode_t *dnephew;  /* sibling's child on the far side from node */

    /* Capture node's side BEFORE detaching it: once the slot is cleared,
     * CHILD_DIR(node) can no longer find node in its parent. */
    dir = CHILD_DIR(node);
    parent->children[dir] = NULL;

    for(;;) {
        sibling = parent->children[1-dir];
        /* node's side is short one black, so the other side is non-empty. */
        assert(sibling != NULL);
        cnephew = sibling->children[  dir];
        dnephew = sibling->children[1-dir];
        if(sibling->color == RED)
            goto case3;
        if(dnephew != NULL && dnephew->color == RED)
            goto case6;
        if(cnephew != NULL && cnephew->color == RED)
            goto case5;
        if(parent->color == RED)
            goto case4;

        /* The parent, both nephews, and sibling are black. Paint the sibling
         * red so both sides are one black short, and move up a level.*/
        sibling->color = RED;
        node = parent;
        if((parent = node->parent) == NULL)
            return; /* Reached the root: every path lost one black; valid. */
        dir = CHILD_DIR(node);
    }

case3:
    /* The sibling is red, which means both nephews and the parent are black.
     * Rotate the sibling up over the parent and swap their colors. node's new
     * sibling is the old close nephew, which is non-NULL and black.*/
    s_node_rotate(tree, parent, dir);
    parent->color = RED;
    sibling->color = BLACK;
    sibling = cnephew;
    dnephew = sibling->children[1-dir];
    if(dnephew != NULL && dnephew->color == RED)
        goto case6;
    cnephew = sibling->children[dir];
    if(cnephew != NULL && cnephew->color == RED)
        goto case5;
    /* Both nephews black under a (now) red parent: handled by case4. */

case4:
    /* Parent red; sibling and nephews black. Swapping the parent and sibling
     * colors restores the missing black on node's side. */
    sibling->color = RED;
    parent->color = BLACK;
    return;

case5:
    /* Close nephew red, distant nephew black. Rotate the sibling away from
     * node so the close nephew becomes the sibling and the old sibling (now
     * red) becomes its distant child, then finish with case6. */
    s_node_rotate(tree, sibling, 1-dir);
    sibling->color = RED;
    cnephew->color = BLACK;
    dnephew = sibling;
    sibling = cnephew;
    /* fallthrough */

case6:
    /* Distant nephew red. Rotate the parent toward node. The sibling takes
     * the parent's color, and the parent and distant nephew become black. */
    s_node_rotate(tree, parent, dir);
    sibling->color = parent->color;
    parent->color = BLACK;
    dnephew->color = BLACK;
}

/**Remove @p node's entry from the tree and release the memory that held it.
 * - Two children: the in-order successor's key/value is copied into @p node
 *   and the successor node is removed instead.
 * - No children: the node is unlinked (rebalancing if it was black) and freed.
 * - One child: the child is spliced into the node's place, painted black,
 *   and the node is freed. The child keeps its identity, so pointers to its
 *   data stay valid.
 * @param tree to remove from.
 * @param node to remove; must be in @p tree.*/
void
s_node_remove(rbtree_t *tree, rbnode_t *node)
{
    rbnode_t *parent = node->parent;

    if(node->left != NULL && node->right != NULL) {
        /* The successor is the left-most node of the right subtree, so it has
         * no left child and is handled by one of the simpler cases below.
         * NOTE(review): the successor node is freed, so any pointer to its
         * data previously handed out (e.g. by rbtree_insert) dangles. */
        rbnode_t *successor = s_node_successor(node);
        s_node_copy_value(tree, node, successor);
        s_node_remove(tree, successor);
        return;
    }

    if(node->left == NULL && node->right == NULL) {
        if(parent == NULL) {
            tree->root = NULL;                  /* Last node in the tree. */
        } else if(node->color == RED) {
            /* Removing a red leaf does not change any black height. */
            parent->children[CHILD_DIR(node)] = NULL;
        } else {
            /* Unlinks node and rebalances; node still needs freeing. */
            s_node_remove_fixup(tree, node);
        }
        kfree(node);
        return;
    }

    /* Exactly one child. In a valid tree it is a red leaf under a black node,
     * so painting it black and putting it in node's place keeps every black
     * height unchanged. */
    rbnode_t *child = node->left != NULL ? node->left : node->right;
    if(parent == NULL) {
        tree->root = child;
    } else {
        int slot = CHILD_DIR(node);   /* node is still linked under parent. */
        parent->children[slot] = child;
    }
    child->parent = parent;
    child->color = BLACK;
    kfree(node);
}

/**Initialise @p tree as an empty tree whose entries each carry @p objw bytes
 * of payload. Members not named here are zero/null-initialised.
 * @param tree to initialise; must not be NULL.
 * @param objw payload size in bytes.*/
void
__rbtree_new(rbtree_t *tree, size_t objw)
{
    rbtree_t fresh = { .objw = objw };
    fresh.root = NULL;
    *tree = fresh;
}

/**Look up the payload stored under @p key.
 * @param tree to search.
 * @param key to look for.
 * @return pointer to the entry's payload inside the tree, or NULL when the
 *         key is not present.*/
void*
rbtree_find(rbtree_t *tree, intmax_t key)
{
    rbnode_t *hit = s_find(tree, key);
    if(hit != NULL)
        return hit->data;
    return NULL;
}

/**Allocate a detached node for @p key with tree->objw bytes of payload.
 * Parent and child links are cleared. The color defaults to BLACK, which is
 * what a node needs as the root of an otherwise empty tree; s_node_insert
 * repaints non-root nodes RED.
 * @param tree supplies the payload width.
 * @param key to store.
 * @param data payload to copy in, or NULL to leave the payload uninitialised.
 * @return the new node, or NULL if kmalloc returned NULL.*/
rbnode_t*
s_node_new(rbtree_t *tree, intmax_t key, void *data)
{
    rbnode_t *node = kmalloc(sizeof(rbnode_t) + tree->objw);
    if(node == NULL) return NULL;
    node->key = key;
    node->parent = NULL;
    node->children[0] = NULL;
    node->children[1] = NULL;
    node->color = BLACK;
    if(data != NULL) memcpy(node->data, data, tree->objw);
    return node;
}

/**Insert the entry for @p key, or overwrite it if it already exists.
 * @param tree to insert into; must not be NULL.
 * @param key of the entry.
 * @param data tree->objw bytes to copy into the entry. May be NULL: a new
 *        entry's payload is then left uninitialised and an existing entry's
 *        payload is left unchanged.
 * @return pointer to the entry's payload inside the tree, or NULL if a new
 *         node could not be allocated.*/
void*
rbtree_insert(rbtree_t *tree, intmax_t key, void *data)
{
    assert(tree != NULL);

    rbnode_t *node = s_closest(tree, key);
    if(node != NULL && node->key == key) {
        /* Existing key: overwrite in place (s_node_new also tolerates NULL). */
        if(data != NULL) memcpy(node->data, data, tree->objw);
        return node->data;
    }

    rbnode_t *fresh = s_node_new(tree, key, data);
    if(fresh == NULL) return NULL;

    if(node != NULL) {
        /* node is where the descent ended; attach on the side it would take. */
        s_node_insert(tree, fresh, node, node->key > key);
    } else {
        /* Empty tree: the new node becomes the black root. */
        fresh->parent = NULL;
        fresh->children[0] = fresh->children[1] = NULL;
        fresh->color = BLACK;
        tree->root = fresh;
    }
    return fresh->data;
}

/**Return the payload slot for @p key, creating an entry if it is missing.
 * Nothing is copied: an existing entry is returned as is, and a new entry's
 * payload is left uninitialised for the caller to fill.
 * @param tree to insert into; must not be NULL.
 * @param key of the entry.
 * @return pointer to the entry's payload inside the tree, or NULL if a new
 *         node could not be allocated.*/
void *rbtree_reserve(rbtree_t *tree, intmax_t key)
{
    assert(tree != NULL);

    rbnode_t *node = s_closest(tree, key);
    if(node != NULL && node->key == key)
        return node->data;

    rbnode_t *fresh = s_node_new(tree, key, NULL);
    if(fresh == NULL) return NULL;

    if(node != NULL) {
        s_node_insert(tree, fresh, node, node->key > key);
    } else {
        /* Empty tree: set the root's links and color here rather than relying
         * on s_node_new, so later traversals never read indeterminate
         * pointers. */
        fresh->parent = NULL;
        fresh->children[0] = fresh->children[1] = NULL;
        fresh->color = BLACK;
        tree->root = fresh;
    }
    return fresh->data;
}