Union Find


Belonging isn't about being next to each other. It's about tracing our roots to the same origin.

Imagine you're managing a massive computer network. You are constantly receiving new connections ("server A is now connected to server B"), and you need to be able to instantly answer the question: "is server X connected to server Y?", even if they are linked through a long, indirect path of other servers.

A standard graph traversal (DFS/BFS) would be too slow, as you'd have to re-traverse the graph every single time you get a query. This is the exact scenario where the Union-Find data structure (also called Disjoint Set Union or DSU) is the perfect tool.

Let's revisit the problem we discussed in the graph traversal chapter, but this time with a focus on the Union-Find data structure.

Problem: Number of Provinces. Given a list of connections between cities, count how many connected components (provinces) there are.

The Core Idea: Groups and Representatives

The Union-Find data structure does one thing incredibly well: it keeps track of which group (or "disjoint set") an element belongs to. It's built on a simple idea: every group has one element that acts as its "representative" or "root."

It supports two main operations:

  • find(i): Tells you the representative of the group that element i is in. If find(i) == find(j), then i and j are in the same group.
  • union(i, j): Merges the groups that i and j belong to into a single group.

Hence the name "Union-Find"!

Building the Union-Find Data Structure

We can implement this using a simple array, parent, where parent[i] stores the parent of element i. An element is its own parent if it's the representative of its group.

A Naive Implementation

class UnionFind:
    def __init__(self, size):
        # Initially, every element is its own parent.
        self.parent = list(range(size))

    def find(self, i):
        # Traverse up until we find the root (the element that is its own parent).
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Make the root of i point to the root of j.
            self.parent[root_i] = root_j

This works, but it can create very tall, unbalanced trees. In the worst case, find could take O(N) time. We can do much better with two key optimizations.
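To see the worst case concretely, here is a small sketch that repeats the naive class under the name NaiveUnionFind and adds a hypothetical depth helper (neither name comes from the original). Merging element 0 with each new element in turn builds a single long chain.

```python
class NaiveUnionFind:
    def __init__(self, size):
        self.parent = list(range(size))

    def find(self, i):
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            self.parent[root_i] = root_j


def depth(uf, i):
    """Hypothetical helper: count parent hops from i up to its root."""
    hops = 0
    while uf.parent[i] != i:
        i = uf.parent[i]
        hops += 1
    return hops


n = 6
uf = NaiveUnionFind(n)
for k in range(1, n):
    # Each union hangs the current root under the new element k.
    uf.union(0, k)

print(uf.parent)     # [1, 2, 3, 4, 5, 5]
print(depth(uf, 0))  # 5
```

With this union order, find(0) has to walk N - 1 links, which is the O(N) behaviour the optimizations are meant to avoid.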

Optimization 1: Path Compression

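When we call find(i), we traverse up the tree to find the root. Path Compression is a simple but powerful trick: once the root is known, we make every node we visited point directly to it (in the recursive version below, this happens as the calls return). This dramatically flattens the tree, so later find calls on those nodes reach the root in one step.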

def find(self, i):
    if self.parent[i] != i:
        # Path Compression: make the parent of i point directly to the root.
        self.parent[i] = self.find(self.parent[i])
    return self.parent[i]
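The recursive version can hit Python's recursion limit on a very deep chain. As an alternative sketch, the same compression can be done with two loops. The standalone function find_iterative and its bare parent list argument are assumptions made for illustration, not part of the class above.

```python
def find_iterative(parent, i):
    # Pass 1: walk up to locate the root.
    root = i
    while parent[root] != root:
        root = parent[root]
    # Pass 2: walk the same path again, pointing each node at the root.
    while parent[i] != root:
        next_node = parent[i]
        parent[i] = root
        i = next_node
    return root


# A chain 0 -> 1 -> 2 -> 3 -> 4, where 4 is the root.
parent = [1, 2, 3, 4, 4]
print(find_iterative(parent, 0))  # 4
print(parent)                     # [4, 4, 4, 4, 4]
```

After one call, every node on the path points straight at the root.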

Optimization 2: Union by Rank (or Size)

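When we call union(i, j), instead of arbitrarily connecting one root to the other, we should be smart about it. Union by Rank involves tracking the "rank" (a rough measure of the height) of each tree in a second array, rank, kept alongside parent. We always attach the root of the shorter tree to the root of the taller tree, and the rank only grows when two trees of equal rank are merged. This helps keep the trees from getting too tall. Union by Size is the same idea with the number of elements in each tree tracked instead of its height.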

def union(self, i, j):
    root_i = self.find(i)
    root_j = self.find(j)
    if root_i != root_j:
        # Union by Rank
        if self.rank[root_i] > self.rank[root_j]:
            self.parent[root_j] = root_i
        elif self.rank[root_i] < self.rank[root_j]:
            self.parent[root_i] = root_j
        else:
            self.parent[root_j] = root_i
            self.rank[root_i] += 1
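The heading mentions size as an alternative to rank. A minimal sketch of that variant might look like the following; the class name UnionFindBySize and the size array are illustrative choices, not taken from the original.

```python
class UnionFindBySize:
    def __init__(self, size):
        self.parent = list(range(size))
        # size[r] is the number of elements in the tree rooted at r.
        self.size = [1] * size

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i == root_j:
            return False
        # Make root_i the larger tree, then hang the smaller one under it.
        if self.size[root_i] < self.size[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        self.size[root_i] += self.size[root_j]
        return True


uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(2, 3)
uf.union(1, 3)
print(uf.size[uf.find(3)])       # 4
print(uf.find(4) == uf.find(0))  # False
```

A side benefit of tracking size is that the number of elements in any group is available directly from its root.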

The Optimized Implementation

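Here is the final, optimized code that you should use in an interview. With both optimizations, the amortized time complexity for find and union becomes O(α(N)), where α is the inverse Ackermann function. It grows so slowly that it is effectively constant for any realistic N. Note that union also returns True when it actually merges two groups and False when the elements were already connected, which is handy for counting components and detecting cycles.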

class UnionFind:
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] == i:
            return i
        # Path Compression
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            # Union by Rank
            if self.rank[root_i] > self.rank[root_j]:
                self.parent[root_j] = root_i
            elif self.rank[root_i] < self.rank[root_j]:
                self.parent[root_i] = root_j
            else:
                self.parent[root_j] = root_i
                self.rank[root_i] += 1
            return True
        return False
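Returning to Number of Provinces, one possible approach is to start with one province per city and subtract one each time union actually merges two groups. The sketch below assumes the input is a city count n and a list of (a, b) pairs with cities labelled 0 to n - 1; the function name count_provinces is hypothetical, and the actual problem statement may present the connections in a different form.

```python
class UnionFind:
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def count_provinces(n, connections):
    uf = UnionFind(n)
    provinces = n  # every city starts as its own province
    for a, b in connections:
        if uf.union(a, b):
            provinces -= 1  # two provinces just became one
    return provinces


print(count_provinces(5, [(0, 1), (1, 2), (3, 4)]))  # 2
print(count_provinces(4, []))                        # 4
```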

More Problems & Variations

  1. Redundant Connection: You are given a graph that started as a tree with N nodes, but one extra edge was added. Find the edge that can be removed to make the graph a tree again.
  2. Graph Valid Tree: Determine if a given graph is a valid tree.
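Both problems come down to the same observation: if union(u, v) reports that u and v are already in the same group, the edge (u, v) closes a cycle. The following is a rough sketch under the assumption that nodes are labelled 0 to n - 1 (the problem statements may number them differently); find_redundant_edge and is_valid_tree are hypothetical names.

```python
class UnionFind:
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def find_redundant_edge(n, edges):
    uf = UnionFind(n)
    for u, v in edges:
        if not uf.union(u, v):
            return (u, v)  # u and v were already connected
    return None


def is_valid_tree(n, edges):
    # A tree on n nodes has exactly n - 1 edges and no cycle.
    if len(edges) != n - 1:
        return False
    return find_redundant_edge(n, edges) is None


print(find_redundant_edge(3, [(0, 1), (0, 2), (1, 2)]))      # (1, 2)
print(is_valid_tree(5, [(0, 1), (0, 2), (0, 3), (1, 4)]))    # True
print(is_valid_tree(4, [(0, 1), (1, 2), (0, 2)]))            # False
```

The edge-count check matters for the tree test: with exactly n - 1 edges and no cycle, the graph is necessarily connected.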

Bonus Points: Minimum Spanning Trees

Union-Find is the engine behind Kruskal's algorithm, a famous method for finding a Minimum Spanning Tree (MST). The algorithm works by sorting all the edges in a graph by weight, from smallest to largest. It then iterates through the sorted edges, adding an edge to the MST only if it does not form a cycle with the edges already added. Union-Find performs this cycle check in near-constant amortized time, so the overall running time is dominated by the O(E log E) sort of the E edges.
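The description above can be sketched in a few lines. This is an illustrative version, assuming a connected graph with nodes labelled 0 to n - 1 and edges given as (weight, u, v) tuples; the function name kruskal and that edge format are choices made for this example.

```python
class UnionFind:
    def __init__(self, size):
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i == root_j:
            return False
        if self.rank[root_i] < self.rank[root_j]:
            root_i, root_j = root_j, root_i
        self.parent[root_j] = root_i
        if self.rank[root_i] == self.rank[root_j]:
            self.rank[root_i] += 1
        return True


def kruskal(n, edges):
    uf = UnionFind(n)
    mst = []
    total = 0
    # Tuples sort by their first field, so this orders edges by weight.
    for weight, u, v in sorted(edges):
        if uf.union(u, v):  # False would mean the edge closes a cycle
            mst.append((u, v, weight))
            total += weight
            if len(mst) == n - 1:
                break  # a spanning tree has exactly n - 1 edges
    return total, mst


edges = [(1, 0, 1), (4, 0, 2), (3, 1, 2), (2, 2, 3), (5, 1, 3)]
total, mst = kruskal(4, edges)
print(total)  # 6
print(mst)    # [(0, 1, 1), (2, 3, 2), (1, 2, 3)]
```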