๋ฌธ์
https://www.acmicpc.net/problem/1311
๋ฐฑ์ค ๋จ๊ณ๋ณ๋ก ํ์ด๋ณด๊ธฐ - ๋์ ๊ณํ๋ฒ 3
ํ์ด
N๋ช ์ ์ฌ๋๊ณผ N๊ฐ์ ์ผ์ด ์๋ค. ๊ฐ ์ฌ๋์ ์ผ์ ํ๋ ๋ด๋นํด์ผ ํ๊ณ , ๊ฐ ์ผ์ ๋ด๋นํ๋ ์ฌ๋์ ํ ๋ช ์ด์ด์ผ ํ๋ค. ๋ํ, ๋ชจ๋ ์ฌ๋์ ๋ชจ๋ ์ผ์ ํ ๋ฅ๋ ฅ์ด ์๋ค.
์ฌ๋์ 1๋ฒ๋ถํฐ N๋ฒ๊น์ง ๋ฒํธ๊ฐ ๋งค๊ฒจ์ ธ ์์ผ๋ฉฐ, ์ผ๋ 1๋ฒ๋ถํฐ N๋ฒ๊น์ง ๋ฒํธ๊ฐ ๋งค๊ฒจ์ ธ ์๋ค.
Dij๋ฅผ i๋ฒ ์ฌ๋์ด j๋ฒ ์ผ์ ํ ๋ ํ์ํ ๋น์ฉ์ด๋ผ๊ณ ํ์ ๋, ๋ชจ๋ ์ผ์ ํ๋๋ฐ ํ์ํ ๋น์ฉ์ ์ต์๊ฐ์ ๊ตฌํ๋ ํ๋ก๊ทธ๋จ์ ์์ฑํ์์ค.
๋์ ๊ณํ๋ฒ 3 ๋ฌธ์ ๋ ๋ชจ๋ ๋นํธ๋ง์คํน์ ํตํด ํ์ด์ผ ํ๋ค.
์์ ์ ๋ ฅ๊ฐ๋๋ก N = 3์ด๋ผ๊ณ ๊ฐ์ ํด๋ณด์.
์ฐ๋ฆฌ๋ 3๊ฐ์ ๋นํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ผ์ ์งํ ์ํฉ์ ๋ํ๋ด์ผ ํ๋ค.
000 => ์๋ฌด ์ผ๋ ์ ํ๋์ง ์์ ์ํฉ
001 => ์ฒซ๋ฒ์งธ ์ผ๋ง ๋ฐฐ์ ๋ ์ํฉ
101 => ์ฒซ๋ฒ์งธ, ์ธ๋ฒ์งธ ์ผ์ด ๋ฐฐ์ ๋ ์ํฉ
์ฆ, ๊ฐ๋ฅํ ๊ฒฝ์ฐ์ ์๋ 1 << 3 = 2^3 = 8 ์ด๋ค. (000~111)
dp์๋ ๋ชจ๋ ์ํ์์ ์ผ์ ๋ฐฐ์ ํ๋๋ฐ ๋๋ ์ต์ ๋น์ฉ์ ์ ์ฅํ๋ค.
๋ฐ๋ผ์ dp = [์์ฒญํฐ๊ฐ] * (1 << 3) ์ผ๋ก ์ด๊ธฐํ ํ ๋ค ์งํํ๋ค.
๊ทธ ๋ค์ ๋ฐ๋ณต๋ฌธ์ ํตํด, ํ์ฌ ์ด๋ค ์ฌ๋์ด ์ผ์ ๋งก์ ์ฐจ๋ก์ธ์ง & ์ด๋ค ์ผ์ด ๋จ์์๋์ง ๋ฅผ ํ๋จํ๋ค.
๋ง์ฝ ํ์ฌ ์ผ์ ๋งก์ ์ ์๋ ์ํฉ์ด๋ผ๋ฉด, dp[ํ์ฌ ์ผ์ ๋งก์์๋์ ๋นํธ๊ฐ] = min(๊ธฐ์กด dp๊ฐ, ์ผ์ ๋งก์ง ์์์๋์ dp๊ฐ + ํ์ฌ ์ผ์ ๋น์ฉ)์ผ๋ก ๋ ์์๊ฐ์ผ๋ก ๊ฐฑ์ ํ๋ค.
์ฐธ๊ณ ๋ก ๋๋ ๋ฐ๋ณต๋ฌธ์ ํตํด ํ์๋๋ฐ, ์ฌ๊ทํจ์๋ก ํ์ด์ผ Python3๋ก ํต๊ณผํ ์ ์๋ค.
(๋๋ ํ๊ฐ๋ฆฌ์ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํด์ผ ํ๋ค. ์ด ์๊ณ ๋ฆฌ์ฆ์ ๋์ค์ ๊ณต๋ถํด๋ด์ผํ ๋ฏ...)
์๋๋ PyPy3๋ก ํต๊ณผํ ์ฝ๋๋ค.
# (PyPy3) ๋ฉ๋ชจ๋ฆฌ: 118788KB / ์๊ฐ: 820ms
from sys import stdin
input = stdin.readline
INF = int(1e9)
N = int(input())
D = [list(map(int, input().split())) for _ in range(N)]
dp = [INF] * (1 << N) # ๊ฒฝ์ฐ์ ์๊ฐ 000~111๊น์ง 8๊ฐ์ด๋ฏ๋ก 1 << 3 => 8
dp[0] = 0
for i in range(1 << N):
x = bin(i).count("1")
for j in range(N):
if not (i & (1 << j)): # ์ด๋ฏธ ์ ํ๋ ์ผ์ด ์๋๋ผ๋ฉด
nxt = i | (1 << j)
dp[nxt] = min(dp[nxt], dp[i] + D[x][j])
print(dp[(1 << N) - 1])
์๋ฅผ๋ค์ด i๊ฐ 6์ด๋ฉด 110์ด๊ณ , x๋ 2๊ฐ ๋๋ค.
๋์ด๊ฐ์ j๊ฐ 0์ผ๋ 1 << 0 = 001 ์ด๋ฏ๋ก 110 & 001 = 0 ์ด ๋ผ์ ์กฐ๊ฑด๋ฌธ์ ํต๊ณผํ๋ค.
nxt = 110 | 001 = 111 = 7
dp[7(111)] = min(dp[7(111)], dp[6(110)] + D[2][0]) ์ผ๋ก ์ต์๋น์ฉ์ ๊ฐฑ์ ํ๋ ๋ฐฉ์์ด๋ค.
(D[2][0] = 3๋ฒ์งธ ์ฌ๋์ด 1๋ฒ์งธ ์ผ์ ํ ๋์ ๋น์ฉ)
์๋๋ Python3๋ก ํต๊ณผํ ๋ถ์ ์ฝ๋๋ค.
๋ฐ๋ณต๋ฌธ์ด๋ ์ฌ๊ท๋์ ์ฐจ์ด์ธ๋ฏ ์ถ๋ค.
์ถ์ฒ๐ https://ji-gwang.tistory.com/446
# ์๋๋ Python3๋ก ํต๊ณผ๋๋ ์ฝ๋๋ค. ์ฌ๊ท๋ฅผ ์ฌ์ฉ.
# ์ถ์ฒ: https://ji-gwang.tistory.com/446
# ๋ฉ๋ชจ๋ฆฌ: 72080KB / ์๊ฐ: 4796ms
import sys
input = sys.stdin.readline
def dfs(row, visit):
if row == N:
return 0
if visited[visit] != -1:
return visited[visit]
ret = 1000000000
for i in range(N):
if (visit & (1 << i)) != 0: # ํน์ ๋นํธ๊ฐ ์ผ์ ์๋ค๋ฉด
continue
ret = min(ret, dfs(row + 1, (visit | (1 << i))) + tasks[row][i])
visited[visit] = ret
return visited[visit]
N = int(input())
tasks = [list(map(int, input().split())) for _ in range(N)]
visited = [-1] * (1 << N)
print(dfs(0, 0))
๊ฐ row(์ฌ๋ ์์)๋๋ก ์ต์ ๋น์ฉ์ ์ ํํ๋ ๋ฐฉ์์ด๋ค. i๊ฐ row, dp๊ฐ visited, j๊ฐ i๋ก ๋ฐ๋ ์ํ๋ผ๊ณ ๋ณด๋ฉด ๋๋ค.
ํด๋น ์ฌ๋์ด N๊ฐ์ ์ผ ์ค ํ๋๋ฅผ ๋งก์ ๊ฒฝ์ฐ๋ฅผ ์ฒดํฌํ๋ค.
ret์ ํ์ฌ ์ผ์ด ์๋ ์ด์ ์ผ์ ์ ํํ์๋์ ๊ฒฝ์ฐ์ ๊ฐ(์ฒ์์ ์์ฒญํฐ๊ฐ)์ ๋ํ๋ด๊ณ ,
dfs(๋ค์ ์ฌ๋, ํ์ฌ ์ผ์ ์ ํํ์ ๊ฒฝ์ฐ์ ์ต์ข ๊ฐ)๊ณผ ๋น๊ตํด ๋ ์ ์ ๊ฐ์ผ๋ก ๊ฐฑ์ ํ๋ค.
๊ทธ๋ฆฌ๊ณ dfs()๋ฅผ ํตํด ์ฒดํฌํ ํด๋น ๊ฒฝ์ฐ์ ๊ฐ์ด ์ด๋ฏธ ์กด์ฌํ๋ค๋ฉด, ๋ค์ ์ฒดํฌํ์ง ์๊ณ ๋ฐ๋ก ๊ธฐ์กด๊ฐ์ ๋ฐํํด์ค๋ค.
์ฌ๊ธฐ์๋ ํจ์จ ์ฐจ์ด๊ฐ ๋ฌ๋๊ฒ๊ฐ๋ค.
๋ง์ง๋ง์ผ๋ก ํ๊ฐ๋ฆฌ์ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ ๋ฐฉ์์ด๋ค.
์คํ์๊ฐ ์ฐจ์ด๊ฐ ๋ง๋์๋๋คใ ใ ใ ์ค๋ง์ด๊น
# ์คํ์๊ฐ 48ms์ธ ์ฝ๋. Hungarian ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํจ.
from sys import stdin
from itertools import product
input = stdin.readline
N = int(input())
table = [list(map(int, input().split())) for _ in range(N)]
def hungarian(table):
N = len(table)
match_x, match_y = [None] * N, [None] * N
label_x, label_y = list(map(max, table)), [0] * N
is_free_x, is_free_y = (lambda x: match_x[x] == None), (
lambda y: match_y[y] == None
)
gap = lambda i, j: label_x[i] + label_y[j] - table[i][j]
while None in match_x:
tree_x, tree_y = [None] * N, [None] * N
S, T = [False] * N, [False] * N
u = next(filter(is_free_x, range(N)))
S[u] = True
slack, slack_x = [gap(u, j) for j in range(N)], [u] * N
while True:
try:
y = next(filter(lambda j: (slack[j] == 0) and (not T[j]), range(N)))
except:
min_gap = min(v for v, b in zip(slack, T) if not b)
for i in range(N):
label_x[i] -= min_gap * int(S[i])
label_y[i] += min_gap * int(T[i])
slack[i] -= min_gap * int(not T[i])
else:
if is_free_y(y):
tree_y[y] = slack_x[y]
while y != None:
x = tree_y[y]
match_y[y], match_x[x], y = x, y, match_x[x]
break
else:
z = match_y[y]
tree_x[z], tree_y[y] = y, slack_x[y]
S[z], T[y] = True, True
for i in range(N):
slack[i], slack_x[i] = min(
(gap(z, i), z), (slack[i], slack_x[i])
)
return list(enumerate(match_x)), sum(label_x) + sum(label_y)
for i, j in product(range(N), repeat=2):
table[i][j] *= -1
_, ans = hungarian(table)
print(-ans)
์ฒ์ ๋ค์ด๋ณด๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ผ์ ์ข ๊ณต๋ถํด์ผ๋ ๋ฏํ๋ค.