A union-find forest, also called a disjoint-set forest, is a data structure that stores a collection of disjoint (non-overlapping) sets.

Operations

Union-find forests support two operations:

  • find: return the representative of the set containing a given element.
  • union: merge the sets containing two given elements.

Initialisation

Union-find forests have two internal data structures:

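  • parent: for each node, a pointer to its parent node.
    • Together, the parent pointers form a forest, with one tree per set.
    • If v.parent == v, node v is the root of its tree and the representative of the set to which it belongs.
  • size: for each node, the number of nodes in the tree rooted at it, including the node itself. Only the values stored at roots are used and kept up to date, so the size of a root is the number of elements in its set.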

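Initially, every node is its own parent, so each element forms a singleton set, and every size is 1. For n elements, both data structures require O(n) space.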
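A minimal sketch of the initial state, assuming the elements are the integers 0 to n-1 and both structures are plain Python lists (the helper name `make_sets` is hypothetical, used only for illustration):

```python
def make_sets(n: int) -> tuple[list[int], list[int]]:
    """Return (parent, size) for n singleton sets over elements 0..n-1."""
    parent = list(range(n))  # each element is its own parent, i.e. a root
    size = [1] * n           # each tree contains only its root
    return parent, size


parent, size = make_sets(4)
print(parent)  # [0, 1, 2, 3]
print(size)    # [1, 1, 1, 1]
```

Each list holds one entry per element, which is where the O(n) space comes from.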

Find

Given a node, the find operation returns the representative of the set to which the node belongs. It does so by repeatedly following parent pointers until it reaches a node that is its own parent; that node is the root of the tree and the representative of the set.
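The loop just described can be sketched as follows, assuming the parent pointers are stored in a plain list indexed by element (the function name `find_root` is illustrative). This version performs no path compression, so it leaves the tree unchanged.

```python
def find_root(parent: list[int], v: int) -> int:
    """Follow parent pointers from v until reaching a root (a node that is its own parent)."""
    while parent[v] != v:
        v = parent[v]
    return v


# A chain 3 -> 2 -> 1 -> 0, plus a singleton set {4}.
parent = [0, 0, 1, 2, 4]
print(find_root(parent, 3))  # 0
print(find_root(parent, 4))  # 4
```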

The running time of find is proportional to the length of the path it follows from the query node up to the root. A flatter tree (lower height) therefore leads to faster find operations. Path compression, which makes every node on the path between the query node and the root point directly to the root, is one way to keep trees flat.
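A sketch of full path compression as just described, again assuming a list-based `parent` (the name `find_compress` is hypothetical). A first pass locates the root, and a second pass re-points every node on the path directly at it; this two-pass form is one common way to write it, and path halving is a one-pass alternative.

```python
def find_compress(parent: list[int], v: int) -> int:
    """Return the root of v and make every node on the path point directly to it."""
    # First pass: walk up to the root.
    root = v
    while parent[root] != root:
        root = parent[root]
    # Second pass: re-point each node on the path at the root.
    while v != root:
        next_v = parent[v]
        parent[v] = root
        v = next_v
    return root


# A chain 4 -> 3 -> 2 -> 1 -> 0.
parent = [0, 0, 1, 2, 3]
print(find_compress(parent, 4))  # 0
print(parent)                    # [0, 0, 0, 0, 0]
```

After one query from the deepest node, every node on the chain is a direct child of the root, so later finds on those nodes take a single step.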

There are several ways to perform path compression. One well-known variant is path halving, which, in a single pass, makes every other node on the path point to its grandparent, as in the following pseudocode:

while x.parent != x:
	x.parent = x.parent.parent	# point x at its grandparent
	x = x.parent	# continue from the new parent
return x	# x is now the root
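To see what path halving does to a tree, here is the same loop applied to a chain, assuming a list-based `parent` (the function name `find_halving` is illustrative). Each visited node skips to its grandparent, and the walk then continues from that grandparent, so every other node on the path is re-pointed.

```python
def find_halving(parent: list[int], x: int) -> int:
    """Find the root of x, re-pointing every other node on the path to its grandparent."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # skip one level
        x = parent[x]                  # continue from the new parent
    return x


# A chain 5 -> 4 -> 3 -> 2 -> 1 -> 0.
parent = [0, 0, 1, 2, 3, 4]
print(find_halving(parent, 5))  # 0
print(parent)                   # [0, 0, 1, 1, 3, 3]
```

The path from node 5 to the root shrinks from five links (5, 4, 3, 2, 1, 0) to three (5, 3, 1, 0), roughly halving it without a second pass.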

Union

union takes two nodes and merges the sets to which they belong. Given two nodes v1 and v2:

  • Find the roots r1 and r2 of v1 and v2 using find.
  • If r1 == r2, v1 and v2 are already in the same set: no further operation is needed.
  • Otherwise, make one root the parent of the other (union by size):
    • If the sizes of the roots differ, the root of the larger tree becomes the parent of the root of the smaller tree.
    • If the sizes of the roots are the same, either one can become the parent.
    • In both cases, increase the size of the new parent by the size of the absorbed tree.
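A sketch of why the size rule matters, assuming list-based `parent` and `size` and no path compression, so that only the union rule differs (all function names here are hypothetical). Merging elements 1 to 7 into the set of 0 one at a time builds a chain when the first root is always attached under the second, but a tree of height 1 when the root of the larger tree is kept as the parent.

```python
def find_root(parent: list[int], v: int) -> int:
    """Plain find with no path compression."""
    while parent[v] != v:
        v = parent[v]
    return v


def union_naive(parent: list[int], v1: int, v2: int) -> None:
    """Always attach the root of v1 under the root of v2, ignoring sizes."""
    r1, r2 = find_root(parent, v1), find_root(parent, v2)
    if r1 != r2:
        parent[r1] = r2


def union_by_size(parent: list[int], size: list[int], v1: int, v2: int) -> None:
    """Attach the root of the smaller tree under the root of the larger one."""
    r1, r2 = find_root(parent, v1), find_root(parent, v2)
    if r1 == r2:
        return  # already in the same set
    if size[r1] < size[r2]:
        r1, r2 = r2, r1  # make r1 the root of the larger (or equal) tree
    parent[r2] = r1
    size[r1] += size[r2]


def height(parent: list[int]) -> int:
    """Largest number of parent links from any node up to its root."""
    best = 0
    for start in range(len(parent)):
        v, steps = start, 0
        while parent[v] != v:
            v = parent[v]
            steps += 1
        best = max(best, steps)
    return best


n = 8
naive = list(range(n))
sized, size = list(range(n)), [1] * n
for i in range(1, n):
    union_naive(naive, 0, i)
    union_by_size(sized, size, 0, i)
print(height(naive))  # 7
print(height(sized))  # 1
```

In general, union by size keeps the height of every tree at most log2(n), so even without path compression each find follows O(log n) links.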

Implementation

Below is a Python implementation of the union-find forest. Both parent and size are stored as plain lists indexed by element.

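class UnionFind:
	def __init__(self, n: int):
		# Elements are 0..n-1; each starts as the root of its own singleton set.
		self.parent = list(range(n))
		self.size = [1] * n

	def find(self, v: int) -> int:
		# Return the representative (root) of the set containing v.
		while v != self.parent[v]:
			self.parent[v] = self.parent[self.parent[v]]  # path halving
			v = self.parent[v]
		return v

	def union(self, v1: int, v2: int) -> None:
		# Merge the sets containing v1 and v2 (union by size).
		r1, r2 = self.find(v1), self.find(v2)
		if r1 == r2:
			return  # already in the same set

		# Attach the root of the smaller tree under the root of the larger one;
		# on a tie, r1 goes under r2.
		if self.size[r1] > self.size[r2]:
			self.parent[r2] = r1
			self.size[r1] += self.size[r2]
		else:
			self.parent[r1] = r2
			self.size[r2] += self.size[r1]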
