code
src/union_find.hpp
#pragma once
#include <algorithm>
#include <map>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
/// Disjoint-set (union-find) structure over the nodes 0 .. n-1, implemented
/// with union by rank and full path compression.
struct UnionFind
{
    /// Creates n singleton sets: every node is its own root and has rank 0.
    UnionFind(int n) : parent(n, 0), rank(n, 0)
    {
        std::iota(parent.begin(), parent.end(), 0);
    }

    /// Returns the root (representative) of the set containing x and
    /// compresses the path from x to that root.
    /// @throws std::out_of_range if x is not in [0, n).
    int find(int x)
    {
        check_index(x);
        return find_root(x);
    }

    /// Merges the sets containing x and y; a no-op when they already share a
    /// root. The root with the lower rank is attached under the other one.
    /// @throws std::out_of_range if x or y is not in [0, n).
    void _union(int x, int y)
    {
        // Validate both arguments first so a bad y cannot leave a half-done call.
        check_index(x);
        check_index(y);

        int px = find_root(x);
        int py = find_root(y);
        if (px == py) {
            return;
        }
        if (rank[px] < rank[py]) {
            std::swap(px, py);
        }
        // From here on px is the root that survives.
        parent[py] = px;
        if (rank[px] == rank[py]) {
            rank[px] += 1;
        }
    }

    /// Returns every set as an ascending list of its members.
    ///
    /// Groups are ordered by their smallest member, so the result depends
    /// only on the partition and not on which node happens to be the root.
    /// This is the same order the pure-Python PythonUnionFind.groups()
    /// produces (first appearance while scanning 0 .. n-1).
    std::vector<std::vector<int>> groups()
    {
        const int n = static_cast<int>(parent.size());
        // slot_of_root[r] is the position in `ret` of the group rooted at r,
        // or -1 while that root has not been encountered yet.
        std::vector<int> slot_of_root(parent.size(), -1);
        std::vector<std::vector<int>> ret;
        for (int i = 0; i < n; ++i) {
            const int root = find_root(i);
            if (slot_of_root[root] < 0) {
                slot_of_root[root] = static_cast<int>(ret.size());
                ret.emplace_back();
            }
            ret[slot_of_root[root]].push_back(i);
        }
        return ret;
    }

    /// Returns the ascending list of all nodes in the same set as x.
    /// @throws std::invalid_argument if x is not a node of this structure.
    std::vector<int> group_of(int x)
    {
        const int n = static_cast<int>(parent.size());
        if (x < 0 || x >= n) {
            throw std::invalid_argument("something went wrong, node: " +
                                        std::to_string(x));
        }
        const int root = find_root(x);
        std::vector<int> members;
        for (int i = 0; i < n; ++i) {
            if (find_root(i) == root) {
                members.push_back(i);
            }
        }
        return members;
    }

    /// Returns a copy of the parent array (parent[i] == i for roots).
    std::vector<int> parent_() const { return parent; }
    /// Returns a copy of the rank array.
    std::vector<int> rank_() const { return rank; }

private:
    /// Throws std::out_of_range unless x is a valid node index.
    void check_index(int x) const
    {
        if (x < 0 || x >= static_cast<int>(parent.size())) {
            throw std::out_of_range("UnionFind: node index out of range: " +
                                    std::to_string(x));
        }
    }

    /// Unchecked root lookup with path compression; x must be a valid index.
    int find_root(int x)
    {
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        // Second pass: point every node on the walked path straight at root.
        while (parent[x] != root) {
            const int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    std::vector<int> parent;
    std::vector<int> rank;
};
pybind11_union_find/__init__.py
from pybind11_union_find_ import * # noqa: F403
from pybind11_union_find_ import __version__ # noqa: F401
from typing import List
from collections import defaultdict
class PythonUnionFind:
    """Pure-Python disjoint-set over the nodes ``0 .. n-1``.

    Uses union by rank and recursive path compression.

    Attributes are plain lists:
      * ``parent[i]`` is the parent of node ``i`` (``parent[i] == i`` for roots).
      * ``rank[i]`` is the rank used to decide which root survives a union.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton sets, each node being its own root with rank 0."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the root of the set containing ``x``, compressing the path to it."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x: int, y: int) -> None:
        """Merge the sets containing ``x`` and ``y``; no-op if they share a root."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return
        if self.rank[px] < self.rank[py]:
            # Keep px as the root that survives (the one with the higher rank).
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1

    def groups(self) -> List[List[int]]:
        """Return every set as an ascending list of members.

        Groups appear in the order their first member is met while scanning
        ``0 .. n-1`` (dict insertion order).
        """
        group_dict = defaultdict(list)
        for i in range(len(self.parent)):
            root = self.find(i)
            group_dict[root].append(i)
        return list(group_dict.values())

    def group_of(self, x: int) -> List[int]:
        """Return the members of the set containing ``x``.

        Raises:
            ValueError: if ``x`` is not a member of any group. ``ValueError``
                is a subclass of ``Exception``, so callers that catch
                ``Exception`` keep working.
        """
        for g in self.groups():
            if x in g:
                return g
        raise ValueError(f"something went wrong, node: {x}")
tests/test_basic.py
from pybind11_union_find import UnionFind as PybindUnionFind
from pybind11_union_find import PythonUnionFind
from pybind11_union_find import __version__
def test_version():
    """The package must report the release string this test suite was written for."""
    expected_version = "0.0.2"
    assert __version__ == expected_version
def test_identical_API():
    """Drive both implementations through one scenario and expect the same answers."""
    merges = ((0, 2), (1, 3), (2, 4))
    expected_roots = (0, 1, 0, 1, 0)
    for implementation in [PythonUnionFind, PybindUnionFind]:
        # Printed so a failing run shows which implementation was under test.
        print(implementation.__name__)
        uf = implementation(5)
        for left, right in merges:
            uf.union(left, right)

        # Roots are checked in node order 0..4, one assertion per node.
        for node, root in enumerate(expected_roots):
            assert uf.find(node) == root

        groups = uf.groups()
        assert groups == [[0, 2, 4], [1, 3]]
        assert uf.group_of(2) == [0, 2, 4]

        # Internal state after the three unions above.
        assert uf.parent == [0, 1, 0, 1, 0]
        assert uf.rank == [1, 1, 0, 0, 0]