Skip to main content

A Generic Red-Black Tree Implementation In Golang

To balance the tree, the CLRS implementation rotates subtrees, whereas the functional implementations rewrite the tree [1,2]. Thinking in terms of rewriting instead of rotating feels like looking at the problem in the proper coordinate system. My implementation here is a translation of Matt Might’s Haskell code into Go.

The generics syntax is easy on the eyes. I like it.

References

  1. The missing method: Deleting from Okasaki’s red-black trees by Matt Might
  2. Red-Black Trees in a Functional Setting by Okasaki
  3. The Next Step for Generics by Ian Lance Taylor and Robert Griesemer

Code

Playground link

// A generic red-black tree implementation from the
// functional implementation by Matt Might[1] and Okasaki.
//
// [1] http://matt.might.net/articles/red-black-delete/
//
// Author: Pratik Deoghare


package main

import (
    "fmt"
)

// Map is an associative container that maps keys of type K to values of
// type V.
type Map[K any, V any] interface {
    // Get returns the value stored under key and reports whether the key
    // was found.
    Get(key K) (V, bool)
    // Set stores value under key.
    Set(key K, value V)
    // Delete removes key from the map.
    Delete(key K)
}

// New returns an empty red-black tree backed Map. Keys are ordered by less;
// two keys are treated as equal when neither is less than the other.
func New[K, V any](less func(K, K) bool) Map[K, V] {
    // The black sentinel is its own left and right child, so grandchild
    // lookups such as n.a.a never hit a nil pointer.
    sentinel := &node[K, V]{color: B}
    sentinel.a, sentinel.b = sentinel, sentinel

    // The double-black sentinel is what remove returns in place of a black
    // node that had no children.
    doubleBlack := &node[K, V]{color: BB, a: sentinel, b: sentinel}

    m := &rbmap[K, V]{
        root:   sentinel,
        leaf:   sentinel,
        bbleaf: doubleBlack,
        less:   less,
    }
    return m
}

// color is the colour tag carried by every node.
type color uint8

// R and B are the ordinary red and black. BB (double black) and NB
// (negative black) only appear while a deletion is being rebalanced.
const (
    R color = iota
    B
    BB
    NB
)

// node is one tree node holding a key/value pair. The leaf sentinels created
// by New are nodes of this type as well.
type node[K, V any] struct {
    // color is R or B at rest; BB and NB occur only during deletion.
    color color
    key   K
    value V
    // a is the left child: lookups descend here when the wanted key is less
    // than key.
    a     *node[K, V]
    // b is the right child: lookups descend here when key is less than the
    // wanted key.
    b     *node[K, V]
}
// rbmap is the red-black tree implementation of Map returned by New.
type rbmap[K, V any] struct {
    // root is the top of the tree; it equals leaf when the map is empty.
    root   *node[K, V]
    leaf   *node[K, V] // the shared leaf sentinel, always black. We don't touch it. It's a sacred leaf.
    bbleaf *node[K, V] // the double-black leaf sentinel; this is used for deletion
    // less orders the keys; keys for which neither direction is less are
    // treated as equal.
    less   func(K, K) bool
}

// Preorder prints the whole tree to standard output, visiting each node
// before its children.
func (r rbmap[K, V]) Preorder() {
    const noIndent = ""
    r.preorder(r.root, noIndent)
}
// preorder prints the subtree rooted at n: first n itself as
// "tab key => value color", then its left and right subtrees with one more
// ":" prepended to tab per level.
func (r rbmap[K, V]) preorder(n *node[K, V], tab string) {
    if n != r.leaf {
        fmt.Println(tab, n.key, "=>", n.value, n.color)
        deeper := ":" + tab
        for _, child := range [...]*node[K, V]{n.a, n.b} {
            r.preorder(child, deeper)
        }
    }
}
// Inorder prints every entry to standard output in ascending key order, one
// "key => value color" line per node.
func (r rbmap[K, V]) Inorder() {
    r.inorder(r.root)
}

// inorder prints the subtree rooted at n: left subtree first, then n itself,
// then the right subtree. The leaf sentinel prints nothing.
func (r rbmap[K, V]) inorder(n *node[K, V]) {
    if n == r.leaf {
        return
    }
    r.inorder(n.a)
    fmt.Println(n.key, "=>", n.value, n.color)
    r.inorder(n.b)
}
// Get looks key up by walking down from the root. It returns the stored
// value and true on a hit, or the zero value of V and false when the walk
// reaches the leaf sentinel.
func (r rbmap[K, V]) Get(key K) (value V, ok bool) {
    for cur := r.root; cur != r.leaf; {
        switch {
        case r.less(key, cur.key):
            cur = cur.a
        case r.less(cur.key, key):
            cur = cur.b
        default:
            // Neither key orders before the other: this is the entry.
            return cur.value, true
        }
    }
    var missing V
    return missing, false
}
// Nil returns the zero value of T.
func Nil[T any]() T {
    return *new(T)
}
// Set stores value under key, overwriting the value of an existing entry
// with an equal key. The root is painted black after the insertion.
func (r *rbmap[K, V]) Set(key K, value V) {
    top := r.insert(r.root, key, value)
    r.root = blacken(top)
}
// blacken paints t black in place and returns the same node.
func blacken[K, V any](t *node[K, V]) *node[K, V] {
    t.color = B
    return t
}
// redden paints t red in place and returns the same node.
func redden[K, V any](t *node[K, V]) *node[K, V] {
    t.color = R
    return t
}
// insert adds key/value to the subtree rooted at n and returns the new root
// of that subtree. A fresh entry becomes a red node with two leaf children;
// an existing entry only has its value replaced. The subtree is rebalanced
// after descending into either child.
func (r *rbmap[K, V]) insert(n *node[K, V], key K, value V) *node[K, V] {
    if n == r.leaf {
        fresh := new(node[K, V])
        fresh.color, fresh.key, fresh.value = R, key, value
        fresh.a, fresh.b = r.leaf, r.leaf
        return fresh
    }
    switch {
    case r.less(key, n.key):
        n.a = r.insert(n.a, key, value)
        return balance(n)
    case r.less(n.key, key):
        n.b = r.insert(n.b, key, value)
        return balance(n)
    default:
        // Equal key: overwrite in place, the tree shape is unchanged.
        n.value = value
        return n
    }
}
// colors reports whether n1, n2 and n3 are coloured c1, c2 and c3
// respectively. Nodes are inspected left to right and checking stops at the
// first mismatch.
func colors[K, V any](n1, n2, n3 *node[K, V], c1, c2, c3 color) bool {
    if n1.color != c1 {
        return false
    }
    if n2.color != c2 {
        return false
    }
    return n3.color == c3
}
// balance repairs a red-red violation directly below n and returns the node
// that should replace n in its parent. If no pattern matches, n is returned
// unchanged.
//
// In every matching pattern x, y, z are the three nodes involved and a, b, c,
// d the four subtrees hanging off them; the result is always y on top with
// children x (holding a, b) and z (holding c, d). Nodes are rewired in place,
// nothing is allocated.
func balance[K, V any](n *node[K, V]) *node[K, V] {
    var x, y, z *node[K, V]
    var a, b, c, d *node[K, V]
    okasakiCase := false
    // Insertion patterns: a black n with a red child that itself has a red
    // child, in each of the four left/right arrangements.
    switch {
    case colors(n, n.a, n.a.a, B, R, R):
        x, y, z = n.a.a, n.a, n
        a, b, c, d = x.a, x.b, y.b, z.b
        okasakiCase = true
    case colors(n, n.a, n.a.b, B, R, R):
        x, y, z = n.a, n.a.b, n
        a, b, c, d = x.a, y.a, y.b, z.b
        okasakiCase = true
    case colors(n, n.b, n.b.a, B, R, R):
        x, y, z = n, n.b.a, n.b
        a, b, c, d = x.a, y.a, y.b, z.b
        okasakiCase = true
    case colors(n, n.b, n.b.b, B, R, R):
        x, y, z = n, n.b, n.b.b
        a, b, c, d = x.a, y.a, z.a, z.b
        okasakiCase = true
    }
    if okasakiCase {
        // All subtrees were captured above, so the links can be overwritten
        // safely. The new top is red with two black children.
        x.a, x.b, z.a, z.b = a, b, c, d
        y.a, y.b = x, z
        x.color, y.color, z.color = B, R, B
        return y
    }
    mightCase := false
    // Deletion patterns: the same four shapes, but with a double-black n.
    switch {
    case colors(n, n.a, n.a.a, BB, R, R):
        x, y, z = n.a.a, n.a, n
        a, b, c, d = x.a, x.b, y.b, z.b
        mightCase = true
    case colors(n, n.a, n.a.b, BB, R, R):
        x, y, z = n.a, n.a.b, n
        a, b, c, d = x.a, y.a, y.b, z.b
        mightCase = true
    case colors(n, n.b, n.b.a, BB, R, R):
        x, y, z = n, n.b.a, n.b
        a, b, c, d = x.a, y.a, y.b, z.b
        mightCase = true
    case colors(n, n.b, n.b.b, BB, R, R):
        x, y, z = n, n.b, n.b.b
        a, b, c, d = x.a, y.a, z.a, z.b
        mightCase = true
    default:
        // No red-red shape matched: try the two patterns that involve a
        // negative-black child of a double-black n.
        c1, ok := deleteCaseI(n)
        if ok {
            return c1
        }
        c2, ok := deleteCaseII(n)
        if ok {
            return c2
        }
    }
    if mightCase {
        // Same rewiring as the insertion cases, but all three nodes end up
        // black, which removes the double black.
        x.a, x.b, z.a, z.b = a, b, c, d
        y.a, y.b = x, z
        x.color, y.color, z.color = B, B, B
        return y
    }
    return n
}
// deleteCaseI handles a double-black n whose right child is negative black
// and has two black children. On a match it rewires the nodes in place and
// returns the new subtree root and true; otherwise it returns n and false
// without modifying anything.
//
// NOTE(review): the leaf sentinel is coloured B, so the condition would also
// accept a leaf as n.b.a or n.b.b, and redden(d) would then recolour the
// shared sentinel. Presumably that cannot occur in a valid tree — confirm.
func deleteCaseI[K, V any](n *node[K, V]) (*node[K, V], bool) {
    cond := n.color == BB && n.b.color == NB && n.b.a.color == B && n.b.b.color == B
    if !cond {
        return n, false
    }
    // x is n, z its negative-black right child, y the left child of z.
    x, y, z := n, n.b.a, n.b
    a, b, c, d := x.a, y.a, y.b, z.b
    x.a, x.b = a, b
    // d is painted red and z black before z is rebalanced, because z with a
    // red d may now form an insertion-style red-red pattern.
    z.a, z.b = c, redden(d)
    z.color = B
    y.a, y.b = x, balance(z)
    x.color, y.color, z.color = B, B, B
    return y, true
}
// deleteCaseII is the mirror image of deleteCaseI: it handles a double-black
// n whose left child is negative black and has two black children. On a
// match it rewires the nodes in place and returns the new subtree root and
// true; otherwise it returns n and false without modifying anything.
//
// NOTE(review): the leaf sentinel is coloured B, so the condition would also
// accept a leaf as n.a.a or n.a.b, and redden(a) would then recolour the
// shared sentinel. Presumably that cannot occur in a valid tree — confirm.
func deleteCaseII[K, V any](n *node[K, V]) (*node[K, V], bool) {
    cond := n.color == BB && n.a.color == NB && n.a.a.color == B && n.a.b.color == B
    if !cond {
        return n, false
    }
    // z is n, x its negative-black left child, y the right child of x.
    x, y, z := n.a, n.a.b, n
    a, b, c, d := x.a, y.a, y.b, z.b
    // a is painted red and x black before x is rebalanced, because x with a
    // red a may now form an insertion-style red-red pattern.
    x.a, x.b = redden(a), b
    z.a, z.b = c, d
    x.color = B
    y.a, y.b = balance(x), z
    x.color, y.color, z.color = B, B, B
    return y, true
}
// Delete removes the entry stored under key, if there is one, and paints the
// resulting root black.
//
// When the last remaining node is removed, del hands back the shared
// double-black sentinel. That sentinel must never be recoloured or installed
// as the root: it would turn into a phantom entry holding the zero key, and
// later deletions could no longer recognise it as double black. Any sentinel
// result therefore maps to the ordinary black leaf, i.e. the empty tree.
func (r *rbmap[K, V]) Delete(key K) {
    top := r.del(r.root, key)
    if top == r.leaf || top == r.bbleaf {
        r.root = r.leaf
        return
    }
    r.root = blacken(top)
}
// del removes key from the subtree rooted at n and returns the new root of
// that subtree. After descending into a child the node is passed through
// bubble; the node holding the key itself is handed to remove.
func (r *rbmap[K, V]) del(n *node[K, V], key K) *node[K, V] {
    if n == r.leaf {
        // Key not present below this point.
        return r.leaf
    }
    switch {
    case r.less(key, n.key):
        n.a = r.del(n.a, key)
    case r.less(n.key, key):
        n.b = r.del(n.b, key)
    default:
        return r.remove(n)
    }
    return r.bubble(n)
}
// remove detaches node n and returns what should take its place:
//   - a red node without children becomes the leaf;
//   - a black node without children becomes the double-black leaf;
//   - a black node with exactly one child, which is red, is replaced by
//     that child painted black;
//   - any other node takes over the largest key/value of its left subtree,
//     that entry is removed from the left subtree, and n is bubbled.
func (r *rbmap[K, V]) remove(n *node[K, V]) *node[K, V] {
    if n == r.leaf {
        return r.leaf
    }
    noLeft, noRight := n.a == r.leaf, n.b == r.leaf
    switch {
    case noLeft && noRight && n.color == R:
        return r.leaf
    case noLeft && noRight && n.color == B:
        return r.bbleaf
    case n.color == B && noLeft && !noRight && n.b.color == R:
        only := n.b
        only.color = B
        return only
    case n.color == B && noRight && !noLeft && n.a.color == R:
        only := n.a
        only.color = B
        return only
    }
    // The right spine of the left subtree is walked twice (max, then
    // removeMax); having max return the node itself would avoid that.
    n.key, n.value = r.max(n.a)
    n.a = r.removeMax(n.a)
    return r.bubble(n)
}
// max follows right children from n until the next one is the leaf and
// returns the key and value of the node it stops at.
func (r *rbmap[K, V]) max(n *node[K, V]) (K, V) {
    rightmost := n
    for rightmost.b != r.leaf {
        rightmost = rightmost.b
    }
    return rightmost.key, rightmost.value
}
// removeMax removes the rightmost node of the subtree rooted at n and
// returns the new subtree root, bubbling every node on the way back up.
func (r *rbmap[K, V]) removeMax(n *node[K, V]) *node[K, V] {
    if n.b != r.leaf {
        n.b = r.removeMax(n.b)
        return r.bubble(n)
    }
    // n has no right child, so n itself is the maximum.
    return r.remove(n)
}
// bubble moves a double black one level up: when either child of n is double
// black, n is made one step blacker and both children one step redder. In
// every case the node is then passed through balance, whose result is
// returned.
func (r *rbmap[K, V]) bubble(n *node[K, V]) *node[K, V] {
    if hasDoubleBlack := n.a.color == BB || n.b.color == BB; hasDoubleBlack {
        n.color = blacker(n.color)
        n.a = r.redder(n.a)
        n.b = r.redder(n.b)
    }
    return balance(n)
}
// redder makes n one step redder. The double-black leaf sentinel is not
// modified; the ordinary leaf is returned in its place. Any other node is
// recoloured in place and returned.
func (r *rbmap[K, V]) redder(n *node[K, V]) *node[K, V] {
    if n != r.bbleaf {
        n.color = redder(n.color)
        return n
    }
    return r.leaf
}
// redder returns the colour one step redder than c: BB -> B, B -> R,
// R -> NB. It panics for NB, which has no redder colour, and for any value
// that is not one of the four declared colours.
func redder(c color) color {
    switch c {
    case BB:
        return B
    case B:
        return R
    case R:
        return NB
    case NB:
        // can't happen
        panic("impossible")
    default:
        panic("why come here")
    }
}
// blacker returns the colour one step blacker than c: NB -> R, R -> B,
// B -> BB. It panics for every other value, BB included, because BB cannot
// be blackened further.
func blacker(c color) color {
    if c == NB {
        return R
    }
    if c == R {
        return B
    }
    if c == B {
        return BB
    }
    panic("unmöglish")
}
// CheckInvariants panics if the tree is not a valid red-black tree: the root
// must be black, a red node must have two black children, and every path
// from the root to a leaf must pass the same number of black nodes.
func (r rbmap[K, V]) CheckInvariants() {
    if r.root.color != B {
        panic("root must be black")
    }
    // check records one black height per leaf reached.
    heights := make([]int, 0)
    r.check(r.root, 0, &heights)
    for i := 1; i < len(heights); i++ {
        if heights[i-1] != heights[i] {
            fmt.Println(&heights)
            panic("black height not same for all the leaves")
        }
    }
}
// check walks the subtree rooted at n. bh is the number of black nodes seen
// on the path so far; at every leaf the count is appended to *xs. It panics,
// after dumping the tree, when a red node has a child that is not black.
func (r rbmap[K, V]) check(n *node[K, V], bh int, xs *[]int) {
    if n == r.leaf {
        *xs = append(*xs, bh)
        return
    }
    switch n.color {
    case R:
        if n.a.color != B || n.b.color != B {
            r.Preorder()
            fmt.Println(n, n.a, n.b)
            panic("red node without both children black")
        }
    case B:
        bh++
    }
    r.check(n.a, bh, xs)
    r.check(n.b, bh, xs)
}

// main builds two small maps with different key and value types, dumps the
// second one in preorder and prints both map values.
func main() {
    intLess := func(x, y int) bool { return x < y }
    strLess := func(x, y string) bool { return x < y }

    a := New[int, string](intLess)
    a.Set(12, "twelve")

    b := New[string, int](strLess)
    for _, k := range []string{"twelve", "a", "b"} {
        b.Set(k, 12)
    }

    // Preorder is not part of the Map interface, hence the assertion to the
    // concrete type.
    b.(*rbmap[string, int]).Preorder()
    fmt.Println(a, b)
}