본문 바로가기
[C++] Data Structure & Algorithm/MST(Minimum Spanning Tree)

[MST] Chapter 01. Disjoint Set (Union-Find)

by song.ift 2023. 8. 14.

GitHub : https://github.com/developeSHG/Data_Structure-Algorithm/commit/131723902448c69701b98d17c20188b35e06dbc0

 

Disjoint Set · developeSHG/Data_Structure-Algorithm@1317239

developeSHG committed Aug 14, 2023

github.com

 

 


 

 

Minimum Spanning Tree => 그래프/트리의 활용정도가 된다.

최소 스패닝 트리 (Minimum Spanning Tree)을 알기 전에 알고 가야할 부분이 "상호 배타적 집합 (Disjoint Set)"이다.

 

DSU라고도 한다.

A*에서 가장 좋은 후보를 찾을 때, 우선순위 큐를 사용을 하면 효율이 좋은 것처럼,

이와 비슷하게 최소 스패닝 트리를 사용할 때, Disjoint Set이 우선순위 큐처럼 최소 스패닝 트리의 좋은 부품처럼 쓰인다

부르는 이름은 Disjoint set || Union Find(합치기 - 찾기)

 

Union Find는 Union과 Find라는 연산을 지원하는 자료구조다.

Find는 대표 원소를 반환하다. (path compression이 이루어진다.)

Union은 서로 다른 disjoint set을 합치는 것이다. 이때 합쳐지는 기준은 보통 union by rank로 이루어진다.

 

#include <iostream>
#include <vector>
#include <list>
#include <stack>
#include <queue>
using namespace std;
#include <thread>

// 그래프/트리 응용
// 오늘의 주제 : 최소 스패닝 트리 (Minimum Spanning Tree)

// 상호 배타적 집합 (Disjoint Set)
// -> 유니온-파인드 Union-Find (합치기-찾기)

// Lineage Battleground (혼종!)
// 혈맹전 + 서바이벌
// 1인 팀 1000명 (팀id 0~999)
// 동맹 (1번팀 + 2번팀 = 1번팀)

void LineageBattleground()
{
	struct User
	{
		int teamId;
		// TODO
	};

	// TODO : UserManager
	vector<User> users;
	for (int i = 0; i < 1000; i++)
	{
		users.push_back(User{ i });
	}

	// 팀 동맹
	// users[1] <-> users[5]
	users[5].teamId = users[1].teamId; // 1

	// 여러 인원에 대해 팀이 동맴한다고 치면 아래처럼 이렇게 진행된다고 하자
	// [0][1][2][3][4]...[999]
	// [1][1][1][1][1]...[2][2][2][2]...[999]

	// teamId=1인 팀과 teamId=2인 팀이 통합
	for (User& user : users)
	{
		if (user.teamId == 1)
			user.teamId = 2;
	}

	// 찾기 연산 O(1)
	// 합치기 연산 O(N)
    
    // 코드에는 1000이라 합치기 연산이 괜찮다고 생각할 수 있지만, 데이터가 커질수록 굉장히 느리다.
    // 이 부분을 효율적으로 만들 수 있는 자료구조가 상호 배타적 집합 (Disjoint Set)
}

// 트리 구조를 이용한 상호 배타적 집합의 표현
// [0] [1] [2] [3] [4] [5]
struct Node
{
	Node* leader;
};

// [2번은 1번 산하], [4, 5번은 3번 산하], [0번은 5번 산하]의 계층구조라 치자
// [1]		[3]
// [2]	 [4][5]
//			[0]

// 효율적이지 않은 무식한 방법으로 구현
// 시간 복잡도 : 트리의 높이에 비례한 시간이 걸림
class NaiveDisjointSet
{
public:
	NaiveDisjointSet(int n) : _parent(n)
	{
		for (int i = 0; i < n; i++)
			_parent[i] = i;
	}

	// 니 대장이 누구니?
	int Find(int u)
	{
		if (u == _parent[u])
			return u;

		return Find(_parent[u]);
	}

	// u와 v를 합친다 (일단 u가 v 밑으로)
	void Merge(int u, int v)
	{
		u = Find(u);
		v = Find(v);

		// 같은 그룹인지를 체크하면 진행할 필요가 없음
		if (u == v)
			return;

		_parent[u] = v;
	}

private:
	vector<int> _parent;
};

// 만약 1번 그룹을 3번 그룹으로 merge 한다 했을 때, 아래처럼 변한다.
//    [3]
// [4][5][1]
//	  [0][2]
 
// 근데 위의 클래스가 효율적이진 않다.
// 시간 복잡도가 트리의 높이에 비례한 시간이 걸리기 때문에 만약 아래처럼 균형이 무너진 트리의 구조가 형성될 수 있다.
// [1]
// [5]
// [4]
// [3]
// [2]
// 리스트와 별반 다를바가 없어진다.

// 트리가 한쪽으로 기우는 문제를 해결?
// 트리를 합칠 때, 항상 [높이가 낮은 트리를] [높이가 높은 트리] 밑으로 -> 1번 그룹이 더 작았으니 2번 그룹으로 합치면 된다는 말 
// (Union-By-Rank) 랭크에 의한 합치기 최적화

// 시간 복잡도 O(Ackermann(n)) = O(1)
class DisjointSet
{
public:
	DisjointSet(int n) : _parent(n), _rank(n, 1)
	{
		for (int i = 0; i < n; i++)
			_parent[i] = i;
	}
   
	// 니 대장이 누구니?
	int Find(int u)
	{
		if (u == _parent[u])
			return u;

		//_parent[u] = Find(_parent[u]);
		//return _parent[u];

		// 위의 주석처리한 코드와 같은 것인데, 이렇게 한 이유는
        // [1]		[3]
		// [2]	 [4][5][0]
		// 		 
        // 결국 5번 산하에 있었던 0도 대장은 3이기때문에, 만약 데이터가 늘어나 트리가 높아질 경우
        // 계속 재귀로 타고 올라가 대장을 찾는 코드가 반복될 수 있기 때문에,
        // 한 번 대장을 찾았으면 equal을 통해 부모를 바꿔준 것
		return _parent[u] = Find(_parent[u]);
	}

	// u와 v를 합친다
	void Merge(int u, int v)
	{
		u = Find(u);
		v = Find(v);

		if (u == v)
			return;

		// u의 트리 높이가 v의 트리 높이보다 더 크면 swap.
        // 아까와 다르게 u와 v의 수치를 판단해서 진행되기 때문에
        // 이렇게만 최적화를 해줘도 트리의 균형이 무너지는 걸 막을 수 있다.
		if (_rank[u] > _rank[v])
			swap(u, v);

		// rank[u] <= rank[v] 보장됨
        // 트리가 더 낮은 그룹대장을 트리가 더 높은 그룹대장의 번호로. (merge)
		_parent[u] = v;

		// 이 코드를 넣은 이유는 만약 트리의 높이가 같다치면
        // [1]		[3]					   [3]
		// [2]	 [4][5]    -- merge --> [4][5][1]
		// [6]		[0]     			   [0][2]
        //                                    [6]
        // 트리의 높이가 늘어나야하기 때문에 ++
		if (_rank[u] == _rank[v])
			_rank[v]++;
	}

private:
	vector<int> _parent;
	vector<int> _rank;
};


int main()
{
	DisjointSet teams(1000);

	teams.Merge(10, 1);
	int teamId = teams.Find(1);
	int teamId2 = teams.Find(10);

	teams.Merge(3, 2);
	int teamId3 = teams.Find(3);
	int teamId4 = teams.Find(2);

	teams.Merge(1, 3);
	int teamId6 = teams.Find(1);
	int teamId7 = teams.Find(3);
}

 

댓글