A Generic Red-Black Tree Implementation in Golang

To balance the tree, the CLRS implementation rotates the subtrees. The functional implementations rewrite the tree instead [1,2]. Thinking in terms of rewriting instead of rotating feels like looking at the problem in a proper coordinate system. My implementation here is a translation of Matt Might's Haskell code into Go.

The generics syntax is easy on the eyes. I like it.

References

1. The missing method: Deleting from Okasaki's red-black trees by Matt Might
2. Red-Black Trees in a Functional Setting by Okasaki
3. The Next Step for Generics by Ian Lance Taylor and Robert Griesemer

Code

playground link
// A generic red-black tree implementation from the
// functional implementation by Matt Might[1] and Okasaki.
//
// [1] http://matt.might.net/articles/red-black-delete/
//
// Author: Pratik Deoghare


package main

import (
        "fmt"
)

// Map is a key/value container with keys of type K and values of type V.
// The only constructor in this file, New, returns an implementation backed
// by a red-black tree ordered by a caller-supplied comparison function.
type Map[K, V any] interface {
        // Get returns the value stored under key. ok is false, and value is
        // the zero value of V, when the key is not present.
        Get(key K) (value V, ok bool)
        // Set stores value under key, replacing any value already stored
        // under an equal key.
        Set(key K, value V)
        // Delete removes key and its value from the map.
        Delete(key K)
}

// New returns an empty Map whose keys are ordered by less.
//
// less(x, y) must report whether x sorts before y. Two keys are treated as
// equal when neither one is less than the other.
func New[K, V any](less func(K, K) bool) Map[K, V] {
        // The black sentinel is its own left and right child, so walking past
        // the bottom of the tree never meets a nil pointer.
        sentinel := new(node[K, V])
        sentinel.color = B
        sentinel.a, sentinel.b = sentinel, sentinel

        // The double-black sentinel hangs the ordinary leaf on both sides.
        doubleBlack := &node[K, V]{color: BB, a: sentinel, b: sentinel}

        m := new(rbmap[K, V])
        m.root = sentinel
        m.leaf = sentinel
        m.bbleaf = doubleBlack
        m.less = less
        return m
}

// color is the colour tag carried by every tree node. Besides plain red and
// black it has two transient values that only appear while a deletion is
// being repaired.
type color uint8

// The numeric values are R=0, B=1, BB=2, NB=3.
const (
        R  color = iota // red
        B               // black
        BB              // double black: one step blacker than B (see blacker)
        NB              // negative black: one step redder than R (see redder)
)

// node is a single tree node. Empty subtrees are represented by a shared
// sentinel node rather than nil, so a and b are expected to be non-nil for
// every node created by this file.
type node[K, V any] struct {
        color color // R or B at rest; BB and NB appear only during deletion
        key   K
        value V
        a     *node[K, V] // left child: insert and Get descend here when the key is less than this node's key
        b     *node[K, V] // right child: insert and Get descend here when the key is greater than this node's key
}
// rbmap is the red-black tree implementation of Map returned by New.
type rbmap[K, V any] struct {
        root   *node[K, V] // top of the tree; equals leaf when the map is empty
        leaf   *node[K, V] // shared black sentinel standing in for every empty subtree; it must stay black
        bbleaf *node[K, V] // double-black sentinel that remove returns in place of a deleted childless black node
        less   func(K, K) bool // strict ordering on keys supplied to New
}

// Preorder prints every entry to standard output, each node before its
// children, starting at the root with an empty indentation prefix.
func (r rbmap[K, V]) Preorder() {
        const noIndent = ""
        r.preorder(r.root, noIndent)
}
// preorder prints the subtree rooted at n: first n itself as
// "<tab> key => value color", then its left and right subtrees with one more
// ":" prepended to tab per level of depth.
func (r rbmap[K, V]) preorder(n *node[K, V], tab string) {
        if n != r.leaf {
                fmt.Println(tab, n.key, "=>", n.value, n.color)
                deeper := ":" + tab
                r.preorder(n.a, deeper)
                r.preorder(n.b, deeper)
        }
}
// Inorder prints every entry to standard output in ascending key order (as
// defined by the map's less function), one "key => value color" line per
// node.
//
// The traversal is iterative: an explicit stack holds the ancestors whose
// left subtree is currently being visited, so no helper method is needed.
func (r rbmap[K, V]) Inorder() {
        var pending []*node[K, V]
        cur := r.root
        for cur != r.leaf || len(pending) > 0 {
                // Slide down the left spine, remembering each node on the way.
                for cur != r.leaf {
                        pending = append(pending, cur)
                        cur = cur.a
                }
                // Everything smaller has been printed; emit this node and
                // continue with its right subtree.
                cur = pending[len(pending)-1]
                pending = pending[:len(pending)-1]
                fmt.Println(cur.key, "=>", cur.value, cur.color)
                cur = cur.b
        }
}
// Get looks key up and returns the stored value with ok == true, or the zero
// value of V with ok == false when no equal key is present.
func (r rbmap[K, V]) Get(key K) (value V, ok bool) {
        for cur := r.root; cur != r.leaf; {
                switch {
                case r.less(key, cur.key):
                        cur = cur.a
                case r.less(cur.key, key):
                        cur = cur.b
                default:
                        // Neither key orders before the other: this is the entry.
                        return cur.value, true
                }
        }
        return Nil[V](), false
}
// Nil returns the zero value of T.
func Nil[T any]() T {
        return *new(T)
}
// Set stores value under key, overwriting the value of an equal key if one
// exists. The root is painted black after every insertion.
func (r *rbmap[K, V]) Set(key K, value V) {
        top := r.insert(r.root, key, value)
        r.root = blacken(top)
}
// paint sets n's colour to c in place and returns n.
func paint[K, V any](n *node[K, V], c color) *node[K, V] {
        n.color = c
        return n
}

// blacken colours n black in place and returns the same node.
func blacken[K, V any](n *node[K, V]) *node[K, V] {
        return paint(n, B)
}

// redden colours n red in place and returns the same node.
func redden[K, V any](n *node[K, V]) *node[K, V] {
        return paint(n, R)
}
// insert adds key/value to the subtree rooted at n and returns the root of
// the resulting subtree.
//
// A new entry becomes a red node with sentinel children. When the key went
// into a child subtree, n is rebalanced on the way back up. When an equal
// key already exists, only its value is replaced and no rebalancing happens.
func (r *rbmap[K, V]) insert(n *node[K, V], key K, value V) *node[K, V] {
        if n == r.leaf {
                fresh := new(node[K, V])
                fresh.color = R
                fresh.key, fresh.value = key, value
                fresh.a, fresh.b = r.leaf, r.leaf
                return fresh
        }
        switch {
        case r.less(key, n.key):
                n.a = r.insert(n.a, key, value)
                return balance(n)
        case r.less(n.key, key):
                n.b = r.insert(n.b, key, value)
                return balance(n)
        }
        n.value = value
        return n
}
// colors reports whether n1, n2 and n3 carry the colours c1, c2 and c3
// respectively. The nodes are checked in that order and the first mismatch
// ends the comparison.
func colors[K, V any](n1, n2, n3 *node[K, V], c1, c2, c3 color) bool {
        if n1.color != c1 {
                return false
        }
        if n2.color != c2 {
                return false
        }
        return n3.color == c3
}
// balance repairs a colour violation located directly beneath n by rewriting
// the nodes involved in place, and returns the root of the rewritten subtree
// (or n itself when no pattern applies).
//
// Each of the red-red patterns names three nodes x, y, z and four subtrees
// a, b, c, d, listed in left-to-right order. All of them are rewired into the
// same shape: y on top, x as its left child holding (a, b), and z as its
// right child holding (c, d).
//
//   - When n is black (the Okasaki insertion cases) the new top y becomes red
//     and x, z become black.
//   - When n is double black (the Might deletion cases) y, x and z all become
//     black.
//   - Otherwise the two negative-black patterns are tried through
//     deleteCaseI and deleteCaseII.
func balance[K, V any](n *node[K, V]) *node[K, V] {
        var x, y, z *node[K, V]
        var a, b, c, d *node[K, V]
        okasakiCase := false
        // Black n with a red child that itself has a red child: the four
        // positions the red-red pair can occupy.
        switch {
        case colors(n, n.a, n.a.a, B, R, R):
                // red left child, red left-left grandchild
                x, y, z = n.a.a, n.a, n
                a, b, c, d = x.a, x.b, y.b, z.b
                okasakiCase = true
        case colors(n, n.a, n.a.b, B, R, R):
                // red left child, red left-right grandchild
                x, y, z = n.a, n.a.b, n
                a, b, c, d = x.a, y.a, y.b, z.b
                okasakiCase = true
        case colors(n, n.b, n.b.a, B, R, R):
                // red right child, red right-left grandchild
                x, y, z = n, n.b.a, n.b
                a, b, c, d = x.a, y.a, y.b, z.b
                okasakiCase = true
        case colors(n, n.b, n.b.b, B, R, R):
                // red right child, red right-right grandchild
                x, y, z = n, n.b, n.b.b
                a, b, c, d = x.a, y.a, z.a, z.b
                okasakiCase = true
        }
        if okasakiCase {
                x.a, x.b, z.a, z.b = a, b, c, d
                y.a, y.b = x, z
                // Red top over two black children.
                x.color, y.color, z.color = B, R, B
                return y
        }
        mightCase := false
        // The same four shapes, but beneath a double-black n.
        switch {
        case colors(n, n.a, n.a.a, BB, R, R):
                x, y, z = n.a.a, n.a, n
                a, b, c, d = x.a, x.b, y.b, z.b
                mightCase = true
        case colors(n, n.a, n.a.b, BB, R, R):
                x, y, z = n.a, n.a.b, n
                a, b, c, d = x.a, y.a, y.b, z.b
                mightCase = true
        case colors(n, n.b, n.b.a, BB, R, R):
                x, y, z = n, n.b.a, n.b
                a, b, c, d = x.a, y.a, y.b, z.b
                mightCase = true
        case colors(n, n.b, n.b.b, BB, R, R):
                x, y, z = n, n.b, n.b.b
                a, b, c, d = x.a, y.a, z.a, z.b
                mightCase = true
        default:
                // No red-red pattern matched. Try the two negative-black
                // patterns; each one checks for a double-black n itself and
                // declines (ok == false) otherwise.
                c1, ok := deleteCaseI(n)
                if ok {
                        return c1
                }
                c2, ok := deleteCaseII(n)
                if ok {
                        return c2
                }
        }
        if mightCase {
                x.a, x.b, z.a, z.b = a, b, c, d
                y.a, y.b = x, z
                // All three nodes end up black: the top is black rather than
                // red, unlike the black-n cases above.
                x.color, y.color, z.color = B, B, B
                return y
        }
        return n
}
// deleteCaseI rewrites the deletion pattern in which n is double black and
// its right child is negative black with two black children.
//
// Calling the nodes left = n, right = n.b and top = n.b.a, the result is a
// black top whose left child is a black left holding (left.a, top.a) and
// whose right child is the rebalanced right holding (top.b, reddened
// right.b). It returns (n, false) without touching anything when the pattern
// does not match.
func deleteCaseI[K, V any](n *node[K, V]) (*node[K, V], bool) {
        if n.color != BB || n.b.color != NB || n.b.a.color != B || n.b.b.color != B {
                return n, false
        }

        left, top, right := n, n.b.a, n.b
        inner, outer, far := top.a, top.b, right.b

        left.b = inner
        right.a = outer
        right.b = redden(far)
        // balance only rewrites beneath a black or double-black root, so
        // right is painted black before it is handed over.
        right.color = B
        rebalanced := balance(right)

        top.a, top.b = left, rebalanced
        left.color, top.color, right.color = B, B, B
        return top, true
}
// deleteCaseII is the mirror image of deleteCaseI: n is double black and its
// left child is negative black with two black children.
//
// Calling the nodes left = n.a, top = n.a.b and right = n, the result is a
// black top whose left child is the rebalanced left holding (reddened
// left.a, top.a) and whose right child is a black right holding (top.b,
// right.b). It returns (n, false) without touching anything when the pattern
// does not match.
func deleteCaseII[K, V any](n *node[K, V]) (*node[K, V], bool) {
        if n.color != BB || n.a.color != NB || n.a.a.color != B || n.a.b.color != B {
                return n, false
        }

        left, top, right := n.a, n.a.b, n
        far, inner, outer := left.a, top.a, top.b

        left.a = redden(far)
        left.b = inner
        right.a = outer
        // balance only rewrites beneath a black or double-black root, so
        // left is painted black before it is handed over.
        left.color = B
        rebalanced := balance(left)

        top.a, top.b = rebalanced, right
        left.color, top.color, right.color = B, B, B
        return top, true
}
// Delete removes key from the map. Deleting a key that is not present leaves
// the stored keys and values unchanged.
//
// del returns one of the shared sentinels when the tree ends up empty:
// r.bbleaf when the last remaining node was removed, r.leaf when the tree
// was already empty. A sentinel must never be recoloured or installed as the
// root. Painting r.bbleaf black would permanently destroy the double-black
// marker that bubble relies on, and a root other than r.leaf makes every
// "n == r.leaf" emptiness test fail. Both sentinels are therefore mapped to
// r.leaf, and only a real node is blackened.
func (r *rbmap[K, V]) Delete(key K) {
        top := r.del(r.root, key)
        if top == r.leaf || top == r.bbleaf {
                r.root = r.leaf
                return
        }
        r.root = blacken(top)
}
// del removes key from the subtree rooted at n and returns the root of the
// resulting subtree, which may be double black.
//
// Reaching the leaf sentinel means the key is absent. When the key lives in a
// child subtree, n is passed through bubble on the way back up so that a
// double-black child is pushed towards the root. When n itself holds the
// key, the work is delegated to remove.
func (r *rbmap[K, V]) del(n *node[K, V], key K) *node[K, V] {
        if n == r.leaf {
                return r.leaf
        }
        switch {
        case r.less(key, n.key):
                n.a = r.del(n.a, key)
        case r.less(n.key, key):
                n.b = r.del(n.b, key)
        default:
                return r.remove(n)
        }
        return r.bubble(n)
}
// remove deletes the node n itself and returns what takes its place.
//
//   - a red node without children is replaced by the leaf sentinel;
//   - a black node without children is replaced by the double-black sentinel;
//   - a black node with a single red child is replaced by that child,
//     repainted black;
//   - otherwise n takes over the key and value of the largest entry in its
//     left subtree, that entry is removed from the left subtree, and n is
//     passed through bubble.
func (r *rbmap[K, V]) remove(n *node[K, V]) *node[K, V] {
        if n == r.leaf {
                return r.leaf
        }

        noLeft, noRight := n.a == r.leaf, n.b == r.leaf
        switch {
        case n.color == R && noLeft && noRight:
                return r.leaf
        case n.color == B && noLeft && noRight:
                return r.bbleaf
        case n.color == B && noLeft && !noRight && n.b.color == R:
                promoted := n.b
                promoted.color = B
                return promoted
        case n.color == B && noRight && !noLeft && n.a.color == R:
                promoted := n.a
                promoted.color = B
                return promoted
        }

        // The left subtree is walked twice here (max, then removeMax). Having
        // max return the node and handing it to removeMax would avoid that.
        n.key, n.value = r.max(n.a)
        n.a = r.removeMax(n.a)
        return r.bubble(n)
}
// max returns the key and value of the rightmost node in the subtree rooted
// at n.
func (r *rbmap[K, V]) max(n *node[K, V]) (K, V) {
        cur := n
        for {
                if cur.b == r.leaf {
                        return cur.key, cur.value
                }
                cur = cur.b
        }
}
// removeMax removes the rightmost node of the subtree rooted at n and
// returns the root of the resulting subtree, bubbling any double black
// upwards along the right spine.
func (r *rbmap[K, V]) removeMax(n *node[K, V]) *node[K, V] {
        if n.b != r.leaf {
                n.b = r.removeMax(n.b)
                return r.bubble(n)
        }
        // n has no right child, so n is the maximum.
        return r.remove(n)
}
// bubble moves a double black one level up. If either child of n is double
// black, n is made one step blacker and both children one step redder. In
// every case n is then handed to balance and the result returned.
func (r *rbmap[K, V]) bubble(n *node[K, V]) *node[K, V] {
        if hasDoubleBlack := n.a.color == BB || n.b.color == BB; hasDoubleBlack {
                n.color = blacker(n.color)
                n.a = r.redder(n.a)
                n.b = r.redder(n.b)
        }
        return balance(n)
}
// redder returns n made one step redder. The double-black sentinel turns
// into the ordinary leaf sentinel. Any other node has its colour changed in
// place by the package-level redder function.
func (r *rbmap[K, V]) redder(n *node[K, V]) *node[K, V] {
        if n != r.bbleaf {
                n.color = redder(n.color)
                return n
        }
        return r.leaf
}
// redder returns the colour one step redder than c: BB -> B, B -> R,
// R -> NB. It panics for NB, which has nothing redder, and for any value
// outside the four defined colours.
func redder(c color) color {
        if c == NB {
                // can't happen
                panic("impossible")
        }
        if c == R {
                return NB
        }
        if c == B {
                return R
        }
        if c == BB {
                return B
        }
        panic("why come here")
}
// blacker returns the colour one step blacker than c: NB -> R, R -> B,
// B -> BB. It panics for every other value, including BB, which cannot be
// blackened further.
func blacker(c color) color {
        if c == NB {
                return R
        }
        if c == R {
                return B
        }
        if c == B {
                return BB
        }
        panic("unmöglish")
}
// CheckInvariants panics if the tree breaks one of the checked red-black
// rules: the root must be black, every red node must have two black
// children (verified by check), and every path from the root to a leaf must
// pass through the same number of black nodes.
func (r rbmap[K, V]) CheckInvariants() {
        if r.root.color != B {
                panic("root must be black")
        }
        // check appends one black-node count per leaf reached.
        heights := make([]int, 0)
        r.check(r.root, 0, &heights)
        for i := 1; i < len(heights); i++ {
                if heights[i] != heights[i-1] {
                        fmt.Println(&heights)
                        panic("black height not same for all the leaves")
                }
        }
}
// check walks the subtree rooted at n. bh is the number of black nodes seen
// on the path from the root down to, but not including, n. At every leaf the
// accumulated count is appended to *xs. A red node with a non-black child
// dumps the tree and panics.
func (r rbmap[K, V]) check(n *node[K, V], bh int, xs *[]int) {
        if n == r.leaf {
                *xs = append(*xs, bh)
                return
        }
        switch n.color {
        case R:
                if n.a.color != B || n.b.color != B {
                        r.Preorder()
                        fmt.Println(n, n.a, n.b)
                        panic("red node without both children black")
                }
        case B:
                bh++
        }
        r.check(n.a, bh, xs)
        r.check(n.b, bh, xs)
}

// main is a small demonstration: it fills one map keyed by int and one keyed
// by string, dumps the second tree in preorder, and prints both map values.
func main() {
        intLess := func(x, y int) bool { return x < y }
        byNumber := New[int, string](intLess)
        byNumber.Set(12, "twelve")

        stringLess := func(x, y string) bool { return x < y }
        byName := New[string, int](stringLess)
        byName.Set("twelve", 12)
        byName.Set("a", 12)
        byName.Set("b", 12)

        // Preorder is not part of the Map interface, so reach for the
        // concrete type.
        tree := byName.(*rbmap[string, int])
        tree.Preorder()
        fmt.Println(byNumber, byName)
}