Trees & Graphs: Traversal & Backtracking
Belonging isn't about being next to each other. It's about tracing our roots to the same origin.
Imagine you're managing a massive computer network. You are constantly receiving new connections ("server A is now connected to server B"), and you need to be able to instantly answer the question: "is server X connected to server Y?", even if they are linked through a long, indirect path of other servers.
A standard graph traversal (DFS/BFS) would be too slow: every query would require re-traversing the graph, costing O(V + E) time per query. This is the exact scenario where the Union-Find data structure (also called Disjoint Set Union or DSU) is the perfect tool.
Let's revisit the problem we discussed in the graph traversal chapter, but this time with a focus on the Union-Find data structure.
Problem: Number of Provinces (LeetCode): Given a list of connections between cities, count how many connected components (provinces) there are.
💡 Hint: How can you efficiently group cities together? What if you could merge the groups for two cities every time you see a direct connection?
The Union-Find data structure does one thing incredibly well: it keeps track of which group (or "disjoint set") an element belongs to. It's built on a simple idea: every group has one element that acts as its "representative" or "root."
It supports two main operations:
- find(i): Tells you the representative of the group that element i is in. If find(i) == find(j), then i and j are in the same group.
- union(i, j): Merges the groups that i and j belong to into a single group.
Hence the name "Union-Find"!
We can implement this using a simple array, parent, where parent[i] stores the parent of element i. An element is its own parent if it's the representative of its group.
class UnionFind:
    def __init__(self, size):
        # Initially, every element is its own parent.
        self.parent = list(range(size))
    def find(self, i):
        # Traverse up until we find the root (the element that is its own parent).
        while self.parent[i] != i:
            i = self.parent[i]
        return i
    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Make the root of i point to the root of j.
            self.parent[root_i] = root_j
class UnionFind {
    private int[] parent;
    public UnionFind(int size) {
        parent = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i; // Initially, every element is its own parent.
        }
    }
    public int find(int i) {
        // Traverse up until we find the root (the element that is its own parent).
        while (parent[i] != i) {
            i = parent[i];
        }
        return i;
    }
    public void union(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        if (rootI != rootJ) {
            // Make the root of i point to the root of j.
            parent[rootI] = rootJ;
        }
    }
}
#include <vector>
#include <numeric> // For std::iota
class UnionFind {
    std::vector<int> parent;
public:
    UnionFind(int size) {
        parent.resize(size);
        std::iota(parent.begin(), parent.end(), 0); // Initialize parent[i] to i
    }
    int find(int i) {
        // Traverse up until we find the root (the element that is its own parent).
        while (parent[i] != i) {
            i = parent[i];
        }
        return i;
    }
    void unite(int i, int j) {
        int root_i = find(i);
        int root_j = find(j);
        if (root_i != root_j) {
            // Make the root of i point to the root of j.
            parent[root_i] = root_j;
        }
    }
};
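To make the two operations concrete, here is a small usage sketch of the basic (unoptimized) Python class applied to the server-network scenario from the introduction. The six-server network and the `connected` helper are hypothetical illustrations, not part of the original code.

```python
class UnionFind:
    """Basic Union-Find from above (no optimizations yet)."""
    def __init__(self, size):
        self.parent = list(range(size))  # every element starts as its own root

    def find(self, i):
        while self.parent[i] != i:  # walk up parent pointers to the root
            i = self.parent[i]
        return i

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i != root_j:
            self.parent[root_i] = root_j  # hang i's tree under j's root


def connected(uf, x, y):
    """Hypothetical helper: two elements are connected iff they share a root."""
    return uf.find(x) == uf.find(y)


# Hypothetical network of servers 0..5; each union is a new connection event.
network = UnionFind(6)
network.union(0, 1)  # "server 0 is now connected to server 1"
network.union(1, 2)
network.union(3, 4)

print(connected(network, 0, 2))  # True: linked indirectly through server 1
print(connected(network, 2, 3))  # False: still separate groups
```

Each new connection is a single `union`, and each query is two `find` calls; no full graph traversal is needed.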
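This works, but it can create very tall, unbalanced trees. In the worst case (for example, when a sequence of unions links the elements into one long chain), find takes O(N) time, and so does union, since it calls find. We can do much better with two key optimizations.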
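The worst case is easy to trigger. The sketch below reuses the basic class, instrumented with a hypothetical `steps` counter, and calls `union(k, k + 1)` in order. Because the root of `i` is always placed under the root of `j`, this builds the chain 0 -> 1 -> ... -> N-1, and a single `find(0)` must hop past every other element.

```python
class NaiveUnionFind:
    """Basic Union-Find from above, plus a hypothetical hop counter."""
    def __init__(self, size):
        self.parent = list(range(size))
        self.steps = 0  # number of parent-pointer hops taken by find

    def find(self, i):
        while self.parent[i] != i:
            i = self.parent[i]
            self.steps += 1
        return i

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i != root_j:
            self.parent[root_i] = root_j  # root of i always goes under root of j


def chain_find_cost(n):
    """Build the degenerate chain, then count the hops of one find from its bottom."""
    uf = NaiveUnionFind(n)
    for k in range(n - 1):
        uf.union(k, k + 1)  # parent[k] = k + 1
    uf.steps = 0            # only measure the final query
    uf.find(0)
    return uf.steps


print(chain_find_cost(1000))  # 999 hops for a single find
```

Repeating such a query N times costs O(N^2) in total, which is what the optimizations below are designed to prevent.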
When we call find(i), we traverse up the tree to find the root. Path Compression is a simple but powerful trick: on the way back down, we make every node we visited point directly to the root. This dramatically flattens the tree.
def find(self, i):
    if self.parent[i] != i:
        # Path Compression: make the parent of i point directly to the root.
        self.parent[i] = self.find(self.parent[i])
    return self.parent[i]
public int find(int i) {
    if (parent[i] != i) {
        // Path Compression: make the parent of i point directly to the root.
        parent[i] = find(parent[i]);
    }
    return parent[i];
}
int find(int i) {
    if (parent[i] != i) {
        // Path Compression: make the parent of i point directly to the root.
        parent[i] = find(parent[i]);
    }
    return parent[i];
}
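One practical caveat for the recursive Python version: CPython's default recursion limit is 1000, so on a very deep tree (such as a long chain built without union by rank) the recursive `find` can raise `RecursionError`. A common alternative is a two-pass iterative find, sketched below as a standalone function over a plain `parent` list; the name `find_iterative` is our own, not part of the original code.

```python
def find_iterative(parent, i):
    """Two-pass path compression without recursion (a common alternative)."""
    # Pass 1: walk up to the root.
    root = i
    while parent[root] != root:
        root = parent[root]
    # Pass 2: walk the same path again, pointing every node directly at the root.
    while parent[i] != root:
        next_node = parent[i]
        parent[i] = root
        i = next_node
    return root


# A chain 0 -> 1 -> 2 -> 3 -> 4 -> 5 (5 is the root).
parent = [1, 2, 3, 4, 5, 5]
print(find_iterative(parent, 0))  # 5
print(parent)                     # [5, 5, 5, 5, 5, 5]: the chain is now flat
```

Once union by rank is also in place, tree height stays logarithmic, so the recursive version is usually safe in practice; the iterative form simply removes the recursion-depth concern entirely.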
When we call union(i, j), instead of arbitrarily connecting one root to the other, we should be smart about it. Union by Rank tracks the "rank" of each tree, an upper bound on its height. We always attach the root of the lower-rank tree under the root of the higher-rank tree; when the two ranks are equal, we pick either root as the new root and increase its rank by one. This keeps every tree's height at O(log N).
def union(self, i, j):
    root_i = self.find(i)
    root_j = self.find(j)
    if root_i != root_j:
        # Union by Rank
        if self.rank[root_i] > self.rank[root_j]:
            self.parent[root_j] = root_i
        elif self.rank[root_i] < self.rank[root_j]:
            self.parent[root_i] = root_j
        else:
            self.parent[root_j] = root_i
            self.rank[root_i] += 1
public void union(int i, int j) {
    int rootI = find(i);
    int rootJ = find(j);
    if (rootI != rootJ) {
        // Union by Rank
        if (rank[rootI] > rank[rootJ]) {
            parent[rootJ] = rootI;
        } else if (rank[rootI] < rank[rootJ]) {
            parent[rootI] = rootJ;
        } else {
            parent[rootJ] = rootI;
            rank[rootI]++;
        }
    }
}
void unite(int i, int j) {
    int root_i = find(i);
    int root_j = find(j);
    if (root_i != root_j) {
        // Union by Rank
        if (rank[root_i] > rank[root_j]) {
            parent[root_j] = root_i;
        } else if (rank[root_i] < rank[root_j]) {
            parent[root_i] = root_j;
        } else {
            parent[root_j] = root_i;
            rank[root_i]++;
        }
    }
}
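To see what union by rank buys us, the sketch below deliberately disables path compression (so tree height can be measured directly) and reports the maximum depth after two merge patterns: the sequential chain that degenerated in the basic version, and a balanced pairwise pattern that makes ranks grow as fast as possible. `RankOnlyUnionFind` and the helper functions are hypothetical names for this illustration.

```python
class RankOnlyUnionFind:
    """Union by rank WITHOUT path compression, so raw tree height is observable."""
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        while self.parent[i] != i:  # plain walk; no compression on purpose
            i = self.parent[i]
        return i

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return
        if self.rank[root_i] > self.rank[root_j]:
            self.parent[root_j] = root_i
        elif self.rank[root_i] < self.rank[root_j]:
            self.parent[root_i] = root_j
        else:
            self.parent[root_j] = root_i
            self.rank[root_i] += 1

    def depth(self, i):
        d = 0
        while self.parent[i] != i:
            i = self.parent[i]
            d += 1
        return d


def max_depth(uf, n):
    return max(uf.depth(i) for i in range(n))


def chain_pattern(n):
    uf = RankOnlyUnionFind(n)
    for k in range(n - 1):
        uf.union(k, k + 1)  # the pattern that produced an O(N) chain before
    return max_depth(uf, n)


def balanced_pattern(n):
    uf = RankOnlyUnionFind(n)
    step = 1
    while step < n:  # merge groups of equal size pairwise: 1+1, 2+2, 4+4, ...
        for start in range(0, n, 2 * step):
            if start + step < n:
                uf.union(start, start + step)
        step *= 2
    return max_depth(uf, n)


print(chain_pattern(16), balanced_pattern(16))  # 1 4
```

In both cases the height stays within floor(log2 N): the chain pattern now yields a tree of depth 1, and even the balanced pattern, which forces ranks to tie at every level, only reaches depth 4 for 16 elements.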
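Here is the final, optimized code that you should use in an interview. With both optimizations, find and union run in amortized O(α(N)) time, where α is the inverse Ackermann function. α(N) is at most 4 for any input size you could realistically store, so each operation is effectively constant time.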
class UnionFind:
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size
    def find(self, i):
        if self.parent[i] == i:
            return i
        # Path Compression
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]
    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Union by Rank
            if self.rank[root_i] > self.rank[root_j]:
                self.parent[root_j] = root_i
            elif self.rank[root_i] < self.rank[root_j]:
                self.parent[root_i] = root_j
            else:
                self.parent[root_j] = root_i
                self.rank[root_i] += 1
            return True
        return False
class UnionFind {
    private int[] parent;
    private int[] rank;
    public UnionFind(int size) {
        parent = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i;
            rank[i] = 1;
        }
    }
    public int find(int i) {
        if (parent[i] == i) return i;
        return parent[i] = find(parent[i]); // Path Compression
    }
    public boolean union(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        if (rootI != rootJ) {
            if (rank[rootI] > rank[rootJ]) {
                parent[rootJ] = rootI;
            } else if (rank[rootI] < rank[rootJ]) {
                parent[rootI] = rootJ;
            } else {
                parent[rootJ] = rootI;
                rank[rootI]++;
            }
            return true;
        }
        return false;
    }
}
#include <numeric> // For std::iota
#include <vector>
class UnionFind {
    std::vector<int> parent;
    std::vector<int> rank;
public:
    UnionFind(int size) {
        parent.resize(size);
        rank.assign(size, 1);
        std::iota(parent.begin(), parent.end(), 0);
    }
    int find(int i) {
        if (parent[i] == i) return i;
        return parent[i] = find(parent[i]); // Path Compression
    }
    bool unite(int i, int j) {
        int root_i = find(i);
        int root_j = find(j);
        if (root_i != root_j) {
            if (rank[root_i] > rank[root_j]) {
                parent[root_j] = root_i;
            } else if (rank[root_i] < rank[root_j]) {
                parent[root_i] = root_j;
            } else {
                parent[root_j] = root_i;
                rank[root_i]++;
            }
            return true;
        }
        return false;
    }
};
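Applying the optimized class to Number of Provinces: start with one province per city and decrement the count each time union actually merges two different groups, which is exactly what its boolean return value reports. This sketch assumes the input is the n x n adjacency matrix `is_connected` used by the LeetCode version of the problem (`is_connected[i][j] == 1` means cities i and j are directly connected); the function name `count_provinces` is our own.

```python
from typing import List


class UnionFind:
    """Optimized Union-Find from above: path compression + union by rank."""
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])  # path compression
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False  # already in the same group
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i  # ensure root_i has the higher rank
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def count_provinces(is_connected: List[List[int]]) -> int:
    n = len(is_connected)
    uf = UnionFind(n)
    provinces = n  # every city starts as its own province
    for i in range(n):
        for j in range(i + 1, n):  # matrix is symmetric, so the upper triangle suffices
            if is_connected[i][j] and uf.union(i, j):
                provinces -= 1  # a successful union merges two provinces into one
    return provinces


print(count_provinces([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))  # 2
```

Counting successful unions avoids a second pass over the cities to collect distinct roots.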
You are given a graph that started as a tree with N nodes, but one extra edge was added. Find the edge that can be removed to make the graph a tree again.
💡 Hint: Iterate through the edges. For each edge (u, v), try to perform union(u, v). If u and v are already in the same set before the union, adding this edge creates a cycle. This is your redundant edge.
Union-Find is the engine behind Kruskal's algorithm, a famous method for finding a Minimum Spanning Tree (MST). The algorithm sorts all the edges in the graph by weight, from smallest to largest. It then iterates through the sorted edges, adding an edge to the MST only if it does not form a cycle with the edges already added. Union-Find performs this cycle check in near-constant amortized time, so the overall running time is dominated by the O(E log E) sort.
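For the redundant-edge problem, the hint translates almost directly into code, because the optimized `union` returns False exactly when both endpoints already share a root. This sketch assumes LeetCode-style input: a list of `[u, v]` edges with nodes labeled 1 to N, so a graph with N nodes has N edges and we allocate N + 1 slots (slot 0 is unused). `find_redundant_connection` is a hypothetical name.

```python
from typing import List


class UnionFind:
    """Optimized Union-Find: path compression + union by rank."""
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])  # path compression
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False  # already connected: this edge would close a cycle
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def find_redundant_connection(edges: List[List[int]]) -> List[int]:
    # Assumption: nodes are labeled 1..N and len(edges) == N (a tree plus one edge).
    uf = UnionFind(len(edges) + 1)
    for u, v in edges:
        if not uf.union(u, v):
            return [u, v]  # u and v were already in the same set
    return []  # no cycle found (input was already a forest)


print(find_redundant_connection([[1, 2], [1, 3], [2, 3]]))  # [2, 3]
```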
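A compact sketch of Kruskal's algorithm on top of the optimized class. The edge format `(weight, u, v)` with nodes numbered 0 to n - 1, and the function name `kruskal_mst`, are assumptions for illustration. If the graph is disconnected, the result is a minimum spanning forest rather than a single tree.

```python
class UnionFind:
    """Optimized Union-Find: path compression + union by rank."""
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])  # path compression
        return self.parent[i]

    def union(self, i, j):
        root_i, root_j = self.find(i), self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def kruskal_mst(n, edges):
    """edges: list of (weight, u, v) tuples (assumed format). Returns (total, mst)."""
    uf = UnionFind(n)
    mst, total = [], 0
    for w, u, v in sorted(edges):  # ascending by weight
        if uf.union(u, v):  # False means this edge would form a cycle: skip it
            mst.append((u, v, w))
            total += w
            if len(mst) == n - 1:  # a spanning tree on n nodes has n - 1 edges
                break
    return total, mst


edges = [(1, 0, 1), (2, 1, 2), (3, 0, 2), (4, 2, 3), (5, 1, 3)]
print(kruskal_mst(4, edges))  # (7, [(0, 1, 1), (1, 2, 2), (2, 3, 4)])
```

The edge (0, 2) with weight 3 is skipped because 0 and 2 are already connected through 1, which is exactly the same cycle check used for the redundant-edge problem.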