Union Find
Trees & Graphs: Traversal & Backtracking
Belonging isn't about being next to each other. It's about tracing our roots to the same origin.
Imagine you're managing a massive computer network. You are constantly receiving new connections ("server A is now connected to server B"), and you need to be able to instantly answer the question "is server X connected to server Y?", even if they are linked through a long, indirect path of other servers.
A standard graph traversal (DFS/BFS) would be too slow, as you'd have to re-traverse the graph every single time you get a query. This is the exact scenario where the Union-Find data structure (also called Disjoint Set Union or DSU) is the perfect tool.
Let's revisit the problem we discussed in the graph traversal chapter, but this time with a focus on the Union-Find data structure.
Problem: Number of Provinces (LeetCode Link): Given a list of connections between cities, count how many connected components (provinces) there are.
💡 Hint: How can you efficiently group cities together? What if you could merge the groups for two cities every time you see a direct connection?
The Core Idea: Groups and Representatives
The Union-Find data structure does one thing incredibly well: it keeps track of which group (or "disjoint set") an element belongs to. It's built on a simple idea: every group has one element that acts as its "representative" or "root."
It supports two main operations:
- find(i): Tells you the representative of the group that element i is in. If find(i) == find(j), then i and j are in the same group.
- union(i, j): Merges the groups that i and j belong to into a single group.
Hence the name "Union-Find"!
Building the Union-Find Data Structure
We can implement this using a simple array, parent, where parent[i] stores the parent of element i. An element is its own parent if it's the representative of its group.
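To make the parent-array idea concrete, here is a minimal sketch that traces roots through a small hand-built array. The six-element forest and the helper name find_root are hypothetical, chosen only for illustration.

```python
# Hypothetical hand-built forest over elements 0..5 (illustration only).
# Group A: 2 -> 1 -> 0   (0 is the representative)
# Group B: 4 -> 3        (3 is the representative)
# Group C: 5 on its own  (5 is its own representative)
parent = [0, 0, 1, 3, 3, 5]


def find_root(parent, i):
    """Follow parent pointers until reaching an element that is its own parent."""
    while parent[i] != i:
        i = parent[i]
    return i


roots = [find_root(parent, i) for i in range(len(parent))]
print(roots)            # [0, 0, 0, 3, 3, 5]
print(len(set(roots)))  # 3 distinct groups
```

Two elements are in the same group exactly when they trace back to the same root, which is why 2 and 0 match while 2 and 4 do not.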
A Naive Implementation
class UnionFind:
    def __init__(self, size):
        # Initially, every element is its own parent.
        self.parent = list(range(size))

    def find(self, i):
        # Traverse up until we find the root (the element that is its own parent).
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Make the root of i point to the root of j.
            self.parent[root_i] = root_j
class UnionFind {
    private int[] parent;

    public UnionFind(int size) {
        parent = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i; // Initially, every element is its own parent.
        }
    }

    public int find(int i) {
        // Traverse up until we find the root (the element that is its own parent).
        while (parent[i] != i) {
            i = parent[i];
        }
        return i;
    }

    public void union(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        if (rootI != rootJ) {
            // Make the root of i point to the root of j.
            parent[rootI] = rootJ;
        }
    }
}
#include <vector>
#include <numeric> // For std::iota

class UnionFind {
    std::vector<int> parent;

public:
    UnionFind(int size) {
        parent.resize(size);
        std::iota(parent.begin(), parent.end(), 0); // Initialize parent[i] to i
    }

    int find(int i) {
        // Traverse up until we find the root (the element that is its own parent).
        while (parent[i] != i) {
            i = parent[i];
        }
        return i;
    }

    void unite(int i, int j) {
        int root_i = find(i);
        int root_j = find(j);
        if (root_i != root_j) {
            // Make the root of i point to the root of j.
            parent[root_i] = root_j;
        }
    }
};
This works, but it can create very tall, unbalanced trees. In the worst case, the tree degenerates into a single chain and find could take O(n) time. We can do much better with two key optimizations.
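A small experiment can show how the naive version degrades. The sketch below reuses the naive logic from this section but adds a step counter; the class name NaiveUnionFind, the method find_with_steps, and the choice of n = 1000 are assumptions made for the demonstration.

```python
class NaiveUnionFind:
    """Naive Union-Find with a find that also reports how many hops it took."""

    def __init__(self, size):
        self.parent = list(range(size))

    def find_with_steps(self, i):
        steps = 0
        while self.parent[i] != i:
            i = self.parent[i]
            steps += 1
        return i, steps

    def union(self, i, j):
        root_i, _ = self.find_with_steps(i)
        root_j, _ = self.find_with_steps(j)
        if root_i != root_j:
            # Same rule as the naive version: root of i points to root of j.
            self.parent[root_i] = root_j


n = 1000
uf = NaiveUnionFind(n)
# Each union hangs the current root under a brand-new element,
# so the tree becomes one long chain 0 -> 1 -> 2 -> ... -> n-1.
for k in range(n - 1):
    uf.union(k, k + 1)

root, steps = uf.find_with_steps(0)
print(root, steps)  # 999 999
```

With this particular order of unions, looking up element 0 has to walk through every other element, which is the linear worst case.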
Optimization 1: Path Compression
When we call find(i), we traverse up the tree to find the root. Path Compression is a simple but powerful trick: as the recursive calls return, we make every node we visited point directly to the root. This dramatically flattens the tree, so later find calls on those nodes are much faster.
def find(self, i):
    if self.parent[i] != i:
        # Path Compression: make the parent of i point directly to the root.
        self.parent[i] = self.find(self.parent[i])
    return self.parent[i]
public int find(int i) {
    if (parent[i] != i) {
        // Path Compression: make the parent of i point directly to the root.
        parent[i] = find(parent[i]);
    }
    return parent[i];
}
int find(int i) {
    if (parent[i] != i) {
        // Path Compression: make the parent of i point directly to the root.
        parent[i] = find(parent[i]);
    }
    return parent[i];
}
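To see the flattening effect, the following sketch runs path compression on a hypothetical five-element chain and prints the parent array afterwards. It also includes an iterative two-pass variant, find_iterative, which is not part of the original text; it is one possible way to avoid deep recursion in Python on long chains.

```python
def find(parent, i):
    """Recursive find with path compression (same logic as the method above)."""
    if parent[i] != i:
        parent[i] = find(parent, parent[i])
    return parent[i]


def find_iterative(parent, i):
    """Assumed alternative: two passes, no recursion."""
    # Pass 1: locate the root.
    root = i
    while parent[root] != root:
        root = parent[root]
    # Pass 2: re-point every node on the path directly at the root.
    while parent[i] != root:
        next_i = parent[i]
        parent[i] = root
        i = next_i
    return root


# Hypothetical chain: 0 -> 1 -> 2 -> 3 -> 4, with 4 as the root.
chain = [1, 2, 3, 4, 4]
print(find(chain, 0))  # 4
print(chain)           # [4, 4, 4, 4, 4]

chain_again = [1, 2, 3, 4, 4]
print(find_iterative(chain_again, 0))  # 4
print(chain_again)                     # [4, 4, 4, 4, 4]
```

After a single find on the deepest element, every node on the path is a direct child of the root.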
Optimization 2: Union by Rank (or Size)
When we call union(i, j), instead of arbitrarily connecting one root to the other, we should be smart about it. Union by Rank involves tracking the "rank" (a rough measure of the height) of each tree. We always attach the root of the shorter tree to the root of the taller tree. This helps keep the trees from getting too tall.
def union(self, i, j):
    root_i = self.find(i)
    root_j = self.find(j)
    if root_i != root_j:
        # Union by Rank
        if self.rank[root_i] > self.rank[root_j]:
            self.parent[root_j] = root_i
        elif self.rank[root_i] < self.rank[root_j]:
            self.parent[root_i] = root_j
        else:
            self.parent[root_j] = root_i
            self.rank[root_i] += 1
public void union(int i, int j) {
    int rootI = find(i);
    int rootJ = find(j);
    if (rootI != rootJ) {
        // Union by Rank
        if (rank[rootI] > rank[rootJ]) {
            parent[rootJ] = rootI;
        } else if (rank[rootI] < rank[rootJ]) {
            parent[rootI] = rootJ;
        } else {
            parent[rootJ] = rootI;
            rank[rootI]++;
        }
    }
}
void unite(int i, int j) {
    int root_i = find(i);
    int root_j = find(j);
    if (root_i != root_j) {
        // Union by Rank
        if (rank[root_i] > rank[root_j]) {
            parent[root_j] = root_i;
        } else if (rank[root_i] < rank[root_j]) {
            parent[root_i] = root_j;
        } else {
            parent[root_j] = root_i;
            rank[root_i]++;
        }
    }
}
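The heading also mentions union by size, which the snippets above do not show. As a rough sketch of that variant, the code below tracks the number of elements in each tree and attaches the smaller tree under the larger one. The class name UnionFindBySize and the group_size helper are assumptions for illustration.

```python
class UnionFindBySize:
    """Union-Find that merges by element count instead of rank (assumed variant)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is only meaningful when r is a root

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])  # path compression
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False
        # Make root_i the root of the larger tree, then attach the smaller one to it.
        if self.size[root_i] < self.size[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        self.size[root_i] += self.size[root_j]
        return True

    def group_size(self, i):
        return self.size[self.find(i)]


uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(2, 3)
uf.union(0, 2)
print(uf.group_size(3))  # 4
print(uf.group_size(4))  # 1
```

A side benefit of tracking size rather than rank is that the size of any group is available immediately, which some problems ask for directly.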
The Optimized Implementation
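Here is the final, optimized code that you should use in an interview. With both optimizations, the amortized time complexity of find and union becomes nearly constant: O(α(n)), where α is the inverse Ackermann function, which is effectively a small constant for any practical input size.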
class UnionFind:
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] == i:
            return i
        # Path Compression
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Union by Rank
            if self.rank[root_i] > self.rank[root_j]:
                self.parent[root_j] = root_i
            elif self.rank[root_i] < self.rank[root_j]:
                self.parent[root_i] = root_j
            else:
                self.parent[root_j] = root_i
                self.rank[root_i] += 1
            return True
        return False
class UnionFind {
    private int[] parent;
    private int[] rank;

    public UnionFind(int size) {
        parent = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i;
            rank[i] = 1;
        }
    }

    public int find(int i) {
        if (parent[i] == i) return i;
        return parent[i] = find(parent[i]); // Path Compression
    }

    public boolean union(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        if (rootI != rootJ) {
            // Union by Rank
            if (rank[rootI] > rank[rootJ]) {
                parent[rootJ] = rootI;
            } else if (rank[rootI] < rank[rootJ]) {
                parent[rootI] = rootJ;
            } else {
                parent[rootJ] = rootI;
                rank[rootI]++;
            }
            return true;
        }
        return false;
    }
}
#include <vector>
#include <numeric> // For std::iota

class UnionFind {
    std::vector<int> parent;
    std::vector<int> rank;

public:
    UnionFind(int size) {
        parent.resize(size);
        rank.assign(size, 1);
        std::iota(parent.begin(), parent.end(), 0);
    }

    int find(int i) {
        if (parent[i] == i) return i;
        return parent[i] = find(parent[i]); // Path Compression
    }

    bool unite(int i, int j) {
        int root_i = find(i);
        int root_j = find(j);
        if (root_i != root_j) {
            // Union by Rank
            if (rank[root_i] > rank[root_j]) {
                parent[root_j] = root_i;
            } else if (rank[root_i] < rank[root_j]) {
                parent[root_i] = root_j;
            } else {
                parent[root_j] = root_i;
                rank[root_i]++;
            }
            return true;
        }
        return false;
    }
};
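With the optimized structure in hand, the Number of Provinces problem from the start of this section can be sketched as follows. This assumes the input is a city count n plus a list of (a, b) pairs with cities numbered 0 to n-1; the function name count_provinces is hypothetical, and the actual LeetCode problem may supply the connections in a different shape (for example, an adjacency matrix), in which case only the loop over connections would change.

```python
class UnionFind:
    """Same optimized structure as above: path compression + union by rank."""

    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def count_provinces(n, connections):
    """Assumed input: n cities labeled 0..n-1 and a list of (a, b) connections."""
    uf = UnionFind(n)
    provinces = n  # start with every city as its own province
    for a, b in connections:
        if uf.union(a, b):
            provinces -= 1  # two separate provinces just merged into one
    return provinces


print(count_provinces(5, [(0, 1), (1, 2), (3, 4)]))  # 2
```

The boolean returned by union is what makes this short: every successful merge reduces the number of groups by exactly one.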
More Problems & Variations
- Redundant Connection (LeetCode Link): You are given a graph that started as a tree with N nodes, but one extra edge was added. Find the edge that can be removed to make the graph a tree again. 💡 Hint: Iterate through the edges. For each edge (u, v), try to perform a union(u, v). If u and v are already in the same set before the union, that means adding this edge created a cycle. This is your redundant edge.
- Graph Valid Tree (LeetCode Link): Determine if a given graph is a valid tree. 💡 Hint: A valid tree has two properties: 1) It has no cycles. 2) It is fully connected. You can use Union-Find to check for cycles. After processing all edges, how many connected components should there be?
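If you want to check your approach against the hints, here is one possible sketch of both problems. The function names, the assumption that Redundant Connection labels nodes 1 to n, and the assumption that Graph Valid Tree labels nodes 0 to n-1 are illustrative choices, so adjust them to the actual problem statement.

```python
class UnionFind:
    """Compact version of the optimized structure from this section."""

    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def find_redundant_edge(n, edges):
    """Assumes nodes are labeled 1..n; returns the first edge that closes a cycle."""
    uf = UnionFind(n + 1)  # index 0 is unused
    for u, v in edges:
        if not uf.union(u, v):
            return (u, v)
    return None


def is_valid_tree(n, edges):
    """Assumes n >= 1 nodes labeled 0..n-1."""
    # A tree on n nodes has exactly n - 1 edges; with that many edges and
    # no cycle, the graph must also be fully connected.
    if len(edges) != n - 1:
        return False
    uf = UnionFind(n)
    return all(uf.union(u, v) for u, v in edges)


print(find_redundant_edge(3, [(1, 2), (1, 3), (2, 3)]))    # (2, 3)
print(is_valid_tree(5, [(0, 1), (0, 2), (0, 3), (1, 4)]))  # True
print(is_valid_tree(5, [(0, 1), (1, 2), (2, 3), (1, 3)]))  # False
```

The edge-count check in is_valid_tree answers the question posed in the hint: after processing all edges of a valid tree, exactly one connected component should remain.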
Bonus Points: Minimum Spanning Trees
Union-Find is the engine behind Kruskal's algorithm, a famous method for finding a Minimum Spanning Tree (MST). The algorithm works by sorting all the edges in a graph by weight, from smallest to largest. It then iterates through the sorted edges, adding an edge to the MST only if it does not form a cycle with the edges already added. Union-Find performs this cycle check in near-constant time, so the overall running time is dominated by the O(E log E) sort of the edges.
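The description of Kruskal's algorithm translates fairly directly into code. The sketch below assumes an undirected graph with n nodes labeled 0 to n-1 and edges given as (weight, u, v) tuples; the function name kruskal and the sample graph are made up for illustration.

```python
class UnionFind:
    """Compact version of the optimized structure from this section."""

    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def kruskal(n, edges):
    """Assumed input: edges as (weight, u, v). Returns (total_weight, chosen_edges)."""
    uf = UnionFind(n)
    total = 0
    chosen = []
    for weight, u, v in sorted(edges):  # smallest weight first
        # union returns False when u and v are already connected,
        # i.e. when this edge would close a cycle.
        if uf.union(u, v):
            total += weight
            chosen.append((u, v, weight))
    return total, chosen


# Hypothetical 4-node graph.
sample_edges = [(1, 0, 1), (4, 0, 2), (2, 1, 2), (7, 1, 3), (3, 2, 3)]
print(kruskal(4, sample_edges))  # (6, [(0, 1, 1), (1, 2, 2), (2, 3, 3)])
```

If the graph is disconnected, this sketch returns a minimum spanning forest rather than a single tree, which you can detect by checking whether fewer than n - 1 edges were chosen.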