Loading [MathJax]/jax/output/HTML-CSS/jax.js

문제 풀이

 

23840번: 두 단계 최단 경로 4

첫째 줄에 정점의 수 N(10 ≤ N ≤ 100,000), 간선의 수 M(10 ≤ M ≤ 300,000)이 주어진다. 다음 M개 줄에 간선 정보 u v w가 주어지며 도시 u와 도시 v 사이의 가중치가 정수 w인 양방향 도로를

www.acmicpc.net

중간 경유지(P)를 최대 20개까지 가질 때 시작점(X)에서 끝점(Z)까지 최단경로를 찾는 문제이다. 내가 들었던 대학교 학부 알고리즘개론 수업의 과제로 나오기도 했다.

출발점->P개의 경유지->도착점까지 가는 경우의 수는 P!개이다. 시간제한이 7초이고 20!>1018이므로 모든 경우의 수를 고려하면 시간 초과가 발생한다.  

 

1. Bitmask DP를 활용하자!

Naive한 풀이가 P!일 때는 Bitmask DP를 생각해볼만 하다. Bitmask DP는 O(P!)O(Pa2P)로 줄여주는 테크닉이라고 생각하면 된다.

먼저 Bitmask DP의 대표문제이자 이 문제와 밀접한 연관이 있는 외판원 순회(Traveling Salesman Problem) 을 이해하고 오는 것이 좋다. 사실상 이 문제는 TSP, Dijkstra의 단순 개념혼합인 셈이다.

먼저, 다익스트라 알고리즘의 시간복잡도는 O(NlogN)이다. P개의 경유점, 출발점, 도착점에 한해서만 다익스트라 알고리즘을 돌려 P+2개의 정점끼리 이동하는 최단경로를 구하면 Bitmask DP를 사용할 준비는 끝났다. 시간복잡도는 O((P+2)NlogN)이다.

 

2. P+1개를 그대로 Bitmask DP에 넣는다면?

외판원 순회 문제의 시간복잡도는 O(N22N)이다. 출발점에서 P+1개의 정점으로 가는 최단경로를 DP 테이블에 보관하고, P+1개의 정점을 그대로 TSP에 적용하면 최대 212221=924844032로, 제한시간 7초를 넘어갈 수 있다. (실제로 시간초과가 나온다.)

우리가 생각해야 하는 것은 도착점도 한정되어 있다는 것이다. 경유점 P개에 대해서만 DP테이블에서 Bitmask DP를 진행하고, DP[1<<P1][Q]+w[Q][Z](0Q<P)의 최솟값을 구해준다면 시간복잡도는 O(P22P)가 된다.

 

3. 시간 복잡도

다익스트라 알고리즘의 시간복잡도와 Bitmask DP를 진행하는 시간복잡도를 합치면 O((P+2)NlogN+P22P)로, 아래의 소스코드는 2580ms에 통과한다.

 

소스 코드

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

vector<vector<ll>> arr;
vector<vector<pair<int,ll>>> edge;
vector<bool> visited;
vector<ll> dist;

void Dijkstra (int n, int start, int p) {
    dist[start] = 0;
    priority_queue<pair<ll,int>, vector<pair<ll,int>>, greater<pair<ll,int>>> pq;
    pq.push({0,start});
    
    while (!pq.empty()) {
        int s = pq.top().second;
        pq.pop();
        visited[s] = true;
        for (int i = 0; i < edge[s].size(); i++) {
            int e = edge[s][i].first;
            ll c = edge[s][i].second;
            if (!visited[e] && dist[e] > dist[s]+c) {
                pq.push({dist[s]+c, e});
                dist[e] = dist[s]+c;
            }
        }
    }
    
    for (int i = 0; i < n; i++) {
        if (dist[i] == LLONG_MAX) arr[p][i] = -1;
        else arr[p][i] = dist[i];
    }
}

int main() {
    ios_base :: sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
    
    int n, m; cin >> n >> m;
    edge.resize(n, vector<pair<int,ll>>());
    visited.resize(n, false);
    dist.resize(n, LLONG_MAX);
    
    for (int i = 0; i < m; i++) {
        int a,b; ll c; cin >> a >> b >> c;
        edge[a-1].push_back({b-1, c});
        edge[b-1].push_back({a-1, c});
    }
    int s, e; cin >> s >> e;
    s--; e--;
    
    int p; cin >> p;
    int route[p];
    for (int i = 0; i < p; i++) cin >> route[i];
    arr.resize(p, vector<ll>(n, 0));
    
    for (int i = 0; i < p; i++) {
        Dijkstra(n, route[i]-1, i);
        route[i]--;
        for (int j = 0; j < n; j++) {
            visited[j]=false; dist[j]=LLONG_MAX;
        }
    }
    
    ll dp[p][1<<p];
    for (int i = 0; i < p; i++) {
        for (int j = 0; j < (1<<p); j++) dp[i][j]=LLONG_MAX;
    }
    for (int i = 0; i < p; i++) {
        if (arr[i][s] != -1) dp[i][1<<i] = arr[i][s];
    }
    int bitcount[1<<p];
    for (int i = 0; i < (1<<p); i++) {
        int c = i;
        int cnt = 0;
        while (c != 0) {
            if (c&1) cnt++;
            c/=2;
        }
        bitcount[i] = cnt;
    }
    vector<int> bitc[p+1];
    for (int i = 0; i < (1<<p); i++) {
        bitc[bitcount[i]].push_back(i);
    }
    for (int q = 1; q < p; q++) {
        for (int j = 0; j < bitc[q].size(); j++) {
                int lj = bitc[q][j];
                for (int i = 0; i < p; i++) {
                    if (dp[i][lj] != LLONG_MAX) {
                        for (int k = 0; k < p; k++) {
                            if (arr[i][route[k]]!=-1 && (lj != (lj|(1<<k)))) {
                                if (dp[k][lj|(1<<k)] > dp[i][lj]+arr[i][route[k]]) {
                                    dp[k][lj|(1<<k)] = dp[i][lj]+arr[i][route[k]];
                                }
                            }
                        }
                    }
                }
        }
    }
    ll ans = LLONG_MAX;
    for (int i = 0; i < p; i++) {
        if (dp[i][(1<<p)-1] != LLONG_MAX && arr[i][e] != -1 && dp[i][(1<<p)-1]+arr[i][e] < ans) ans = dp[i][(1<<p)-1]+arr[i][e];
    }
    if (ans == LLONG_MAX) cout << "-1";
    else cout << ans;
}

 

'BOJ' 카테고리의 다른 글

[백준 1078] 뒤집음  (0) 2022.06.27
[백준 22021] 자동분무기  (0) 2022.06.26
[백준 9661] 돌 게임7  (1) 2022.06.22
[백준 1007] 벡터 매칭 - 백트래킹  (0) 2022.06.20
[백준 20149] 선분 교차 3  (0) 2022.06.18

+ Recent posts