Union-Find – 陪你刷題

Union-Find data structure ,又稱為 disjoint-set data structure,用於處理 disjoint set 的查詢與合併問題,最常見應用,用來解決圖中的 connected component 問題 。

Disjoint set

先看 "set" 這個單字,set 在 computer science 領域指的是一組資料的集合,set 內部的資料是不重複的,資料之間的順序並不重要。

而 disjoint set 表示數個 set 之間,擁有的元素都不相同,彼此互斥(disjoint)。例如 A = {1, 2, 3}, B = {4, 5} ,A 和 B 即為 disjoint set 。

Union-Find Algorithm

Union-Find data structure 是一種 forest 結構,forest 是一種 N-way Tree 結構,互相連通的節點放在同一組 set ,並以 forest 來表示,任意選擇其中一個節點作為 root 。

Union-Find 提供以下兩種操作:

  • Find : 確定元素屬於哪一個子集 (找到節點的 root )。可以使用此操作確定兩個元素是否屬於同一子集。
  • Union : 將兩個子集合併為同一子集。

Find 函數背後概念為找到節點的 root ,如果要確認兩個節點是否屬於同一子集,只要分別找兩節點的 root ,如果一樣,即代表屬於同一子集。

Find 的時間複雜度,最差狀況就是遍歷整棵樹,若樹呈現極度不平衡如同一個 linked list ,時間複雜度為 O (n) 。

Union 能夠將兩個點所屬的子集進行合併,合併最簡單方法是將一個子集的 root 直接作為另外一個子集 root 的子節點即可,如下圖:

Union 的實現需要依靠 Find ,因此時間複雜度最差將為 O (n) 。

Find 跟 Union 的執行時間都是線性等級,這樣的資料結構顯然不是好用的,你可能會想到平衡二元樹,其樹的高度維持在 O (log n) ,因此想要改善時間複雜度,最好是避免樹的不平衡。

Weighted Quick Union - Tree 平衡優化

以上圖來說,如果由 2 作為 0 的子節點,合併過後的樹高度會比較小,根據這個觀察可以歸納出,如果兩個子集要合併,應該讓高度較小的子集合併到比較大的子集下。

Path Compression

Path compression 可以進一步優化,讓每個節點直接連到他的 root 節點,這樣 Find 跟 Union 操作的時間複雜度可以降低到 O (1) 。

要達成壓縮,只需在 Find 中加入以下程式碼,建議可以自己畫圖感受程式碼的行為,while 終止條件為節點的 parent 等於自己,代表無法再壓縮了,壓縮完的 tree 的高度一定不大於 3 。

int Find(int a) {
    while (a != parents[a]) {
        // a's grandparent is now a's parent
        parents[a] = parents[parents[a]];
        a = parents[a];
    }
    return a;
}

Union by size / rank with path compression

透過 path compression 後的樹可以維持高度在 3 以內 ,若要再 union 兩個壓縮後的子集,可以採取 union by size 技巧,將子節點比較少的 root 加入比較大的子集。

另外一種是依照 rank 來排序,起初每個點的 rank 均為 0 ,依據 rank 大小來決定如何合併,rank 大的子集合併小的,同時前者的 rank 往上增加。

如果有了 path compression ,是否還需要依據 size 或 rank 來決定合併順序?

有了 path compression , forest 的高度必定不超過 3 ,維持在常數等級,由時間複雜度角度來看確實沒有差別,但由下圖合併完的兩個 case 來看,情況 1 會另外做好幾次的 path compression 來將第三層的節點移到第二層,因此依據 size 或 rank 來決定合併確實更佳優化。

Leetcode #200 Number of Islands

題目所求就是從圖中找出 component 數量。

class UnionFind 負責處理 union-find data structure ,在 class UnionFind 的 Constructor 中,針對每個 land 節點先將其 parent 設為自己,也代表每個節點自己都是一個 connected component ,而 water 節點的 parent 設為 -1 。

透過 DFS 由每個 land 出發,走訪過的點直接修改為 water 節點,避免被重複執行 Union ,每執行一次 Union ,代表少了一個 component 數量,所有 land 都透過 DFS 走訪過,即可得 component 數量。

class UnionFind {
    public:
    UnionFind(vector<vector<char>>& grid) {
        count = 0;
        for (int i=0; i<grid.size (); i++)
        {
            for (int j=0; j<grid[0].size (); j++)
            {
                if (grid[i][j] == '1')
                {
                    parent.push_back (i*grid[0].size() + j);
                    count++;
                }
                else
                {
                    parent.push_back (-1);
                }
                rank.push_back (i*grid[0].size() + j);
            }
        }
    }

    int find(int i) { // path compression
        while (i != parent[i])
        {
            parent[i] = parent[parent[i]];
            i = parent[i];
        }
        return parent[i];
    }

    void Union(int x, int y) { // union with rank
        int root_x = find (x);
        int root_y = find (y);
        if (root_x != root_y)
        {
            if (rank[root_x] > rank[root_y])
            {
                parent[root_y] = root_x;
            }
            else if (rank[root_y] > rank[root_x])
            {
                parent[root_x] = root_y;
            }
            else
            {
                parent[root_y] = root_x;
                rank[root_x] +=1;
            }
            count--;
        }
    }

    int getCount() const {
        return count;
    }

    private:
    vector<int> parent;
    vector<int> rank;
    int count; // # of connected components
};

class Solution {
    public:
    int numIslands(vector<vector<char>>& grid) {
        int col = grid.size ();
        int row = grid[0].size ();

        UnionFind uf (grid);
        for (int i=0; i<col; i++)
        {
            for (int j=0; j<row; j++)
            {
                if (grid[i][j] == '1')
                {
                    grid[i][j] = '0';
                    if (i-1 >= 0 && grid[i-1][j] == '1') uf.Union (i*row+j, (i-1)*row+j);
                    if (i+1 < col && grid[i+1][j] == '1') uf.Union (i*row+j, (i+1)*row+j);
                    if (j-1 >=0 && grid[i][j-1] == '1') uf.Union (i*row+j, i*row+(j-1));
                    if (j+1 < row && grid[i][j+1] == '1') uf.Union (i*row+j, i*row+(j+1));
                }
            }
        }
        return uf.getCount ();
    }
};

時間複雜度

時間複雜度為 O (M x N) ,M, N 分別為輸入陣列的長跟寬,最壞狀況下 DFS 會將所有點都走過。

空間複雜度

空間複雜度為 O (M x N) ,每個點需要紀錄其 parent 與 rank 。

Leetcode #684 Redundant Connection

本題要問去除哪一個 edge 來避免圖中形成 cycle ,回想 Union-Find algorithm ,Union 方法就是將兩個子集合在一起,並形成一個新的 forest 結構,但是兩個子集可以合為一個 forest ,是根基於兩個子集沒有交集,也就是說他們的 root 不一樣,若是兩子集有共同 root ,又將兩個點連起來,就會在 forest 中形成 cycle 。

將每個 edge 的兩點 union 起來,在 Union 操作內將執行 Find 找兩點的 root, 若 root 一樣,代表找到會形成 cycle 的 edge 。

class Solution {
public:
    vector<int> parent;
    vector<int> rank;
    int Find (int x)
    {
        while (x != parent[x])
        {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }
    bool Union (int x, int y)
    {
        int rootX = Find (x);
        int rootY = Find (y);
        if (rootX == rootY)
        {
            return false;
        }
        else
        {
            if (rank[x] > rank[y])
            {
                parent[y] = x;
            }
            else if (rank[y] > rank[x])
            {
                parent[x] = y;
            }
            else
            {
                parent[rootX] = rootY;
                rank[y]++;
            }
        }
        return true;
    }
    vector<int> findRedundantConnection(vector<vector<int>>& edges)
    {
        for (int i=0; i<= edges.size(); i++)
        {
            parent.push_back (i);
            rank.push_back (0);
        }

        for (int i=0; i<edges.size(); i++)
        {
            if (!Union (edges[i][0], edges[i][1]))
            {
                return vector<int> {edges[i][0], edges[i][1]};
            }
        }
        return {};
    }
};

時間複雜度

O (N) ,其中 N 可能為題目給的 edge 數量或是 node 數量。

空間複雜度

O (N) ,N 為 node 數量。

延伸問題 Leetcode #261, #547, #721

Reference

  1. Disjoint Set Union CP-Algorithms
  2. Disjoint-set data structure - wikipedia
  3. 普林斯頓課程學習筆記1-union-find
  4. 演算法筆記 Set
  5. http://web.ntnu.edu.tw/~algo/SpanningTree.html#2
  6. Union-Find算法详解 - labuladong算法博客

Updated on 2021-02-24 20:16:38 星期三