๋ฌธ์
https://www.acmicpc.net/problem/16398
๋ฐฑ์ค ๋ฌธ์ ์ง - 0x1B๊ฐ - ์ต์ ์ ์ฅ ํธ๋ฆฌ
์๊ณ ๋ฆฌ์ฆ: ์ต์ ์คํจ๋ ํธ๋ฆฌ
ํ์ด

ํฌ๋ฃจ์ค์นผ์ด ์ต์ํ ํฐ๋ผ ํฌ๋ฃจ์ค์นผ → ํ๋ฆผ ์์ผ๋ก ํ์๋ค.
ํฌ๋ฃจ์ค์นผ๋ก๋ ์์ํ ํต๊ณผ๋์ง๋ง, ์ด ๋ฌธ์ ๋ "๋ชจ๋ ๊ฐ์ "์ด ์ฃผ์ด์ง๋ฏ๋ก ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๋๊ฒ ๋ ํจ์จ์ ์ด๋ค.
ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ ์ฌ์ฉ
๊ฐ์ ์ด ์ง์ ์ ์ผ๋ก ์ฃผ์ด์ง๋๊ฒ ์๋๋ผ(u v w), ํ๋ ฌ ์์ฒด๋ก ์ฃผ์ด์ง๋ค.
๊ทธ๋ฆฌ๊ณ A → B = B → A ์๋ฐฉํฅ ๊ทธ๋ํ์ด๋ฏ๋ก ์ฃผ๋๊ฐ์ (i) ์ ๊น์ง๋ง ์ฒดํฌํด์ฃผ๋ฉด ๋๋ค.
(๋จ๋ฐฉํฅ์ด๋ผ๋ฉด ๋ชจ๋ ์ ์ฅํด์ค์ผํจ)
๊ทธ ๋ค์์ ์ผ๋ฐ์ ์ธ ํฌ๋ฃจ์ค์นผ ํ์ด์ ๊ฐ๋ค.
- ๋น์ฉ ๊ธฐ์ค์ผ๋ก ๊ฐ์ ์ ๋ ฌ
- ์ ๋ ฌ๋ ๊ฐ์ ๋ค์ ํ์
- ๊ฐ์ ์งํฉ์ด ์๋๋ผ๋ฉด ์ฐ๊ฒฐ ํ ๋น์ฉ ์ ์ฅ
์ ์ฒด ์ฝ๋
# 1. ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ ํ์ด
# ๋ฉ๋ชจ๋ฆฌ: 102560KB / ์๊ฐ: 948ms
from sys import stdin
input = stdin.readline
N = int(input())
graph = []
parent = list(range(N))
for i in range(N):
line = list(map(int, input().split()))
# ์ฃผ๋๊ฐ์ (์๊ธฐ ์์ ) ์ ๊น์ง๋ง ์ฒดํฌ
for j in range(i):
graph.append((i, j, line[j]))
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
def union(a, b):
a, b = find(a), find(b)
if a != b:
if parent[a] < parent[b]:
parent[b] = a
else:
parent[a] = b
return True
return False
graph.sort(key=lambda x: x[2])
ret = 0
for u, v, w in graph:
# ๊ฐ์ ์งํฉ์ด ์๋๋ผ๋ฉด ๊ฒฐ๊ณผ๊ฐ์ ๋น์ฉ ์ถ๊ฐ
if union(u, v):
ret += w
print(ret)
ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ ์ฌ์ฉ (heapq, ๋ฆฌ์คํธ)
ํ๋ฆผ์ ์ฌ์ฉํ ๊ฒฝ์ฐ ๋ชจ๋ ๊ฐ์ ๋ค์ ์ ์ฅํด๋ฌ์ผํ๋ค.
ํ์ฌ ์ ํ๋ ๋
ธ๋๋ฅผ ๊ธฐ์ค์ผ๋ก ์ธ์ ๋
ธ๋๋ค์ ํ์ํด์ผํ๊ธฐ ๋๋ฌธ์ด๋ค.
์ด ์ธ์๋ ์ญ์ ์ผ๋ฐ์ ์ธ ํ๋ฆผ ํ์ด์ ๋์ผํ๋ค.
(ํ์ ์ฌ์ฉํ ๊ฒ์ธ์ง, ๋ฆฌ์คํธ๋ฅผ ์ฌ์ฉํ ๊ฒ์ธ์ง๋ ๊ฐ์ธ ์ ํ)
ํ์ ์ฌ์ฉํ์๋์ ๊ณผ์ ์ ์๋์ ๊ฐ๋ค.
๋จผ์ ์ต์๊ฐ๋ค์ ์ ์ฅํ ๋ฆฌ์คํธdistance๋ฅผ ์์ฑํด์ฃผ๊ณ , heap์ ์์์ ๋
ธ๋์ ๊ฐ(0)์ ์ง์ด๋ฃ๋๋ค.
- heap์์ (๋น์ฉ, ๋ ธ๋)๋ฅผ ๊บผ๋ธ๋ค.
- ์ด๋ฏธ ๋ฐฉ๋ฌธํ ๋ ธ๋๋ผ๋ฉด ๋์ด๊ฐ๋ค.
- ์๋๋ผ๋ฉด ๋ฐฉ๋ฌธ์ฒ๋ฆฌ ํ ์๋์ ๊ณผ์ ์ ์คํํ๋ค.
- ์ ์ฒด ๋น์ฉ์ ๊บผ๋ธ ๋น์ฉ์ ๋ํด์ค
- ํ์ธํ ๋
ธ๋์ ์๋ฅผ +1 ์นด์ดํธ
- ์นด์ดํ ํ ๋ ธ๋์ ์๊ฐ N๊ฐ์ผ๊ฒฝ์ฐ break
- ํ์ฌ ๋
ธ๋์ ์ฐ๊ฒฐ๋ ์ธ์ ๋
ธ๋๋ค์ ํ์ (๋น์ฉ, ๋
ธ๋๋ฒํธ)
- ์ด๋ฏธ ๋ฐฉ๋ฌธํ ๋ ธ๋๊ฑฐ๋ ํ์ฌ๊น์ง์ ์ต์๊ฐ์ด ๋น์ฉ๋ณด๋ค ์๊ฑฐ๋ ๊ฐ๋ค๋ฉด ๋์ด๊ฐ
- ์๋๋ผ๋ฉด, heap์
(๋น์ฉ, ๋ ธ๋๋ฒํธ)ํํ๋ก ์ถ๊ฐ - ํด๋น ๋น์ฉ๊ฐ์ผ๋ก ์ต์๊ฐ ๊ฐฑ์ =>
distance[๋ ธ๋๋ฒํธ] = ๋น์ฉ
์ ์ฒด ์ฝ๋
# 2. ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ ํ์ด (heapq ์ฌ์ฉ)
# ๋ฉ๋ชจ๋ฆฌ: 74296KB / ์๊ฐ: 392ms
from sys import stdin
from heapq import heappush, heappop
input = stdin.readline
N = int(input())
graph = [tuple(map(int, input().split())) for _ in range(N)]
distance = [float("inf")] * N # ํ์ฌ๊น์ง ๋ง๋ค์ด์ง MST์ ์ฐ๊ฒฐ๋ ์ ์๋ ์ต์ ๋น์ฉ
visited = [False] * N
ret = cnt = 0
heap = [(0, 0)]
while heap:
cost, node = heappop(heap)
if visited[node]: # ์ด๋ฏธ ๋ฐฉ๋ฌธํ ๋
ธ๋๋ผ๋ฉด ํจ์ค
continue
visited[node] = True
ret += cost
cnt += 1
# ์นด์ดํ
ํ ๋
ธ๋์ ๊ฐฏ์๊ฐ N๊ฐ๋ผ๋ฉด break
if cnt >= N:
break
# ๋ค์ ๋
ธ๋์ ๋ฒํธ, ๋น์ฉ
for nxt, cost in enumerate(graph[node]):
if visited[nxt] or distance[nxt] <= cost: # ์ด๋ฏธ ๋ฐฉ๋ฌธํ ๋
ธ๋๊ฑฐ๋ ์ต์๊ฐ์ด cost๋ณด๋ค ์๋ค๋ฉด ๋์ด๊ฐ๋ค.
continue
heappush(heap, (cost, nxt))
distance[nxt] = cost
print(ret)
๋ค์์ ๋๊ฐ์ ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ์ด์ง๋ง, ๋ฆฌ์คํธ๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ์์ด๋ค.
๊ธฐ๋ณธ์ ์ธ ํ๋ก์ฐ๋ ํ ์ฌ์ฉ ๋ฒ์ ๊ณผ ๊ฐ๋ค.
๋ค๋ง ๋ฆฌ์คํธ๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ ๋งค๋ฒ "ํ์ฌ MST์์ ์ฐ๊ฒฐํ ์ ์๋ ๋ ธ๋ ์ค ์ต์๋น์ฉ์ธ ๋ ธ๋"๋ฅผ ํ์ํด์ผํ๋ค.
- ์๋ก ์ฐ๊ฒฐํ ๋ ธ๋ ํ์
- ํ์ฌ MST๋ฅผ ๊ธฐ์ค์ผ๋ก ์ฐ๊ฒฐ๋น์ฉ์ด ์ต์์ธ ๋ ธ๋ ์ ํ
- ํด๋น ๋
ธ๋๋ฅผ MST์ ์ถ๊ฐ
- ํด๋น ๋ ธ๋ ๋ฐฉ๋ฌธ์ฒ๋ฆฌ, ์ ์ฒด๋น์ฉ์ ์ฐ๊ฒฐ ๋น์ฉ ์ถ๊ฐ
- ์ถ๊ฐํ ๋
ธ๋์ ์ธ์ ๋
ธ๋๋ค ํ์
- ์ด๋ฏธ ๋ฐฉ๋ฌธํ ๋ ธ๋๊ฑฐ๋ ํ์ฌ๊น์ง์ ์ต์๊ฐ์ด ๋น์ฉ๋ณด๋ค ์๊ฑฐ๋ ๊ฐ๋ค๋ฉด ๋์ด๊ฐ
- ์๋๋ผ๋ฉด ์ต์๊ฐ ๊ฐฑ์
์ ์ฒด ์ฝ๋
# 3. ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ ํ์ด (๋ฆฌ์คํธ ์ฌ์ฉ)
# ๋ฉ๋ชจ๋ฆฌ: 71096KB / ์๊ฐ: 504ms
from sys import stdin
input = stdin.readline
N = int(input())
graph = [tuple(map(int, input().split())) for _ in range(N)]
distance = [float("inf")] * N # MST์ ํฌํจ๋ ๋
ธ๋๋ค๋ก๋ถํฐ ๊ฐ ๋
ธ๋๊น์ง์ ์ต์ ์ฐ๊ฒฐ๋น์ฉ
visited = [False] * N
# 0๋ฒ์งธ ๋
ธ๋์ ๊ฐ์ 0์ผ๋ก ์ค์ ํ MST ์์ฑ ์์
distance[0] = 0
ret = 0
# ๋ชจ๋ ๋
ธ๋๋ฅผ MST์ ํฌํจ์ํฌ๋๊น์ง ๋ฐ๋ณต
for _ in range(N):
# ํ์ฌ MST์ ์ฐ๊ฒฐํ ์ ์๋ ๊ฐ์ฅ ์์ ๋น์ฉ์ ๋
ธ๋๋ฅผ ์ฐพ์์ผํจ.
min_dis = float("inf")
min_node = -1
for i in range(N):
if not visited[i] and distance[i] < min_dis:
min_dis = distance[i]
min_node = i
# MST์ ํฌํจ์ํจ ๋ค ๊ฒฐ๊ณผ๊ฐ์ ๋น์ฉ ์ถ๊ฐ
visited[min_node] = True
ret += min_dis
# ์๋ก ์ถ๊ฐ๋ ๋
ธ๋๋ก๋ถํฐ ๋ค๋ฅธ ๋
ธ๋๋ค๊น์ง์ ์ต์ ์ฐ๊ฒฐ ๋น์ฉ ๊ฐฑ์
for nxt in range(N):
if visited[nxt] or distance[nxt] <= graph[min_node][nxt]:
continue
distance[nxt] = graph[min_node][nxt]
print(ret)
๊ธฐ๋ณธ์ ์ธ ์ต์ ์ ์ฅ ํธ๋ฆฌ ๋ฌธ์ ๋ค. ์ฐ์ตํ๊ธฐ ์ข์๋ฏ.