Problem Solving/백준

14284 : 간선 이어가기 2 - Python

greatwhite 2024. 11. 12. 11:59

14284 : 간선 이어가기 2

접근


처음엔 단순하게 간선과 가중치가 나오길래 최소 신장 트리를 생각했었다. 하지만, 모든 정점을 연결하는 것이 아닌 특정 정점 두 개만 연결시켜야 했고 최소신장트리로는 풀 수 없음을 느꼈다.

 

그 다음으로 생각한 것은 유니온 파인드인데, 이것도 두 정점을 연결했을 때 최소 비용임을 보장할 수 없다.

 

마지막으로 생각한 것은 다익스트라였다. 모든 정점을 다 고려해도 두 정점 간의 최소 비용으로 연결을 보장한다.

간선 연결 과정에서 최악의 형태는 1자형태로 쭉 뻗은 그래프 형태가 되는데, 그렇게 되면 간선이 N - 1개가 된다. 그렇다면 최대 100,000개의 간선에는 중복이 있을 수 밖에 없다. 따라서, 그래프 저장시 가중치가 더 낮은 간선 취하면 될 듯 싶었다.

풀이


풀 때는 중복되는 간선이 많을 수 있어 배열방식으로 최솟값 갱신하는 방식이 더 빠를 것이라고 생각했는데, 의외로 리스트에 모든 간선 다 저장하는 방식이 더 빨랐다.

아마 추측해보면, 배열 방식을 사용했을 때는 각 정점마다 N - 1개의 간선을 모두 확인하고 넘어가기 때문인 것 같다.

전체 코드


배열 방식

import sys
from heapq import heappop, heappush

# print = sys.stdout.write
input = sys.stdin.readline

# 반드시 아래 두 라인 주석 처리 후 제출
f = open("input.txt", "rt")
input = f.readline
# 반드시 위 두 라인 주석 처리 후 제출
INF = float("inf")

N, M = map(int, input().split())
graph = [[INF] * (N + 1) for _ in range(N + 1)]
for i in range(1, N + 1):
    graph[i][i] = 0

for _ in range(M):
    a, b, c = map(int, input().split())
    graph[a][b] = min(graph[a][b], c)
    graph[b][a] = min(graph[b][a], c)

S, T = map(int, input().split())

def main():
    print(dijkstra())

def dijkstra():
    dist = [INF] * (N + 1)
    dist[S] = 0

    pq = [(0, S)]
    while pq:
        cost, to = heappop(pq)
        if to == T:
            return dist[T]

        for next in range(1, N + 1):
            if next == to or graph[to][next] == INF:
                continue

            next_cost = cost + graph[to][next]
            if next_cost >= dist[next]:
                continue
            dist[next] = next_cost
            heappush(pq, (next_cost, next))

if __name__ == "__main__":
    main()

리스트 방식

import sys
from heapq import heappop, heappush

# print = sys.stdout.write
input = sys.stdin.readline

# 반드시 아래 두 라인 주석 처리 후 제출
f = open("input.txt", "rt")
input = f.readline
# 반드시 위 두 라인 주석 처리 후 제출
INF = float("inf")

N, M = map(int, input().split())
graph = [[] for _ in range(N + 1)]

for _ in range(M):
    a, b, c = map(int, input().split())
    graph[a].append((b, c))
    graph[b].append((a, c))

S, T = map(int, input().split())

def main():
    print(dijkstra())

def dijkstra():
    dist = [INF] * (N + 1)
    dist[S] = 0

    pq = [(0, S)]
    while pq:
        current_cost, current = heappop(pq)
        if current == T:
            return dist[T]

        for next, next_cost in graph[current]:
            total = current_cost + next_cost
            if total >= dist[next]:
                continue

            dist[next] = total
            heappush(pq, (total, next))

if __name__ == "__main__":
    main()