code

src/union_find.hpp

#pragma once

#include <algorithm>
#include <map>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/// Disjoint-set (union-find) forest over the nodes 0 .. n-1.
///
/// Uses path compression in find() and union by rank in _union(). The
/// observable results (groups(), group_of(), parent_(), rank_()) are meant to
/// match the pure-Python PythonUnionFind implementation exactly.
struct UnionFind
{
    /// @brief Create n singleton sets {0}, {1}, ..., {n-1}.
    /// @throws std::invalid_argument if n is negative.
    UnionFind(int n)
    {
        // Without this check a negative n would silently convert to a huge
        // size_t inside the vector constructor.
        if (n < 0) {
            throw std::invalid_argument(
                "UnionFind size must be non-negative, got: " +
                std::to_string(n));
        }
        parent.resize(static_cast<std::size_t>(n));
        std::iota(parent.begin(), parent.end(), 0); // every node is its own root
        rank.assign(static_cast<std::size_t>(n), 0);
    }

    /// @brief Return the root of x's set, compressing the path to it.
    /// @throws std::out_of_range if x is not in [0, n).
    int find(int x)
    {
        check_index(x);
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        // Second pass: point every node on the path directly at the root.
        while (parent[x] != root) {
            const int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    /// @brief Merge the sets containing x and y (no-op if already merged).
    /// @throws std::out_of_range if x or y is not in [0, n).
    void _union(int x, int y)
    {
        auto px = find(x);
        auto py = find(y);

        if (px == py) {
            return;
        }

        // Attach the lower-ranked root under the higher-ranked one.
        if (rank[px] < rank[py]) {
            std::swap(px, py);
        }

        parent[py] = px;
        if (rank[px] == rank[py]) {
            rank[px] += 1;
        }
    }

    /// @brief Return all sets.
    ///
    /// Groups are ordered by their smallest member, and members are in
    /// ascending order. This is the same ordering PythonUnionFind.groups()
    /// produces (ordering by root id would diverge from it).
    std::vector<std::vector<int>> groups()
    {
        const int n = static_cast<int>(parent.size());
        // slot[root] = index into ret of that root's group, or -1 if unseen.
        std::vector<int> slot(parent.size(), -1);
        std::vector<std::vector<int>> ret;
        for (int i = 0; i < n; ++i) {
            const int root = find(i);
            if (slot[root] < 0) {
                slot[root] = static_cast<int>(ret.size());
                ret.emplace_back();
            }
            ret[slot[root]].push_back(i);
        }
        return ret;
    }

    /// @brief Return the members of the set containing x, in ascending order.
    /// @throws std::invalid_argument if x is not in [0, n).
    std::vector<int> group_of(int x)
    {
        const int n = static_cast<int>(parent.size());
        if (x < 0 || x >= n) {
            throw std::invalid_argument("something went wrong, node: " +
                                        std::to_string(x));
        }
        const int root = find(x);
        std::vector<int> members;
        for (int i = 0; i < n; ++i) {
            if (find(i) == root) {
                members.push_back(i);
            }
        }
        return members;
    }

    /// @brief Copy of the current parent array (reflects path compression so far).
    std::vector<int> parent_() const { return parent; }
    /// @brief Copy of the current rank array.
    std::vector<int> rank_() const { return rank; }

  private:
    /// Throw std::out_of_range unless 0 <= x < n.
    void check_index(int x) const
    {
        if (x < 0 || x >= static_cast<int>(parent.size())) {
            throw std::out_of_range("node index out of range: " +
                                    std::to_string(x));
        }
    }

    std::vector<int> parent;
    std::vector<int> rank;
};

pybind11_union_find/__init__.py

from pybind11_union_find_ import *  # noqa: F403
from pybind11_union_find_ import __version__  # noqa: F401

from typing import List
from collections import defaultdict


class PythonUnionFind:
    """Pure-Python disjoint-set (union-find) forest over nodes ``0 .. n-1``.

    Mirrors the API of the pybind11-backed ``UnionFind`` so the two can be
    compared side by side. ``find`` compresses paths and ``union`` attaches
    the lower-ranked root under the higher-ranked one.
    """

    def __init__(self, n: int):
        # parent[i] is i's parent; i is a root when parent[i] == i.
        self.parent = list(range(n))
        # rank[i] drives union by rank; it only grows when two equal-rank
        # roots are merged.
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the root of ``x``'s set, pointing ``x`` (and every node on
        the way) directly at that root."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x: int, y: int) -> None:
        """Merge the sets containing ``x`` and ``y`` (no-op if already merged)."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return

        # Keep px as the root with the larger (or equal) rank.
        if self.rank[px] < self.rank[py]:
            px, py = py, px

        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1

    def groups(self) -> List[List[int]]:
        """Return every set as a list of members.

        Members are in ascending order, and groups are ordered by their
        smallest member (dict insertion order of the first node seen per root).
        """
        group_dict = defaultdict(list)
        for i in range(len(self.parent)):
            root = self.find(i)
            group_dict[root].append(i)
        return list(group_dict.values())

    def group_of(self, x: int) -> List[int]:
        """Return the members of the set containing ``x``, in ascending order.

        Raises:
            ValueError: if ``x`` is not one of the nodes. This matches the
                ``std::invalid_argument`` the C++ implementation throws, which
                pybind11 translates to ``ValueError``.
        """
        for g in self.groups():
            if x in g:
                return g
        raise ValueError(f"something went wrong, node: {x}")

tests/test_basic.py

from pybind11_union_find import UnionFind as PybindUnionFind
from pybind11_union_find import PythonUnionFind
from pybind11_union_find import __version__


def test_version():
    """The package must report the expected release version."""
    expected_version = "0.0.2"
    assert __version__ == expected_version


def test_identical_API():
    """Run the same operations on both implementations and require identical
    roots, groups, parent arrays and rank arrays."""
    implementations = (PythonUnionFind, PybindUnionFind)
    for impl in implementations:
        print(impl.__name__)
        uf = impl(5)
        for a, b in ((0, 2), (1, 3), (2, 4)):
            uf.union(a, b)
        expected_roots = [0, 1, 0, 1, 0]
        for node, root in enumerate(expected_roots):
            assert uf.find(node) == root
        assert uf.groups() == [[0, 2, 4], [1, 3]]
        assert uf.group_of(2) == [0, 2, 4]
        assert uf.parent == [0, 1, 0, 1, 0]
        assert uf.rank == [1, 1, 0, 0, 0]