Algoogle

Algorithm for Programming Contest

PKU 1986 Distance Queries

Category: PKU Tag: binary-indexed-tree, lowest-common-ancestor, segment-tree

Distance Queries

問題概要


全域木になるようなグラフが与えられ, 各枝には重みが付いている.
頂点数はN(<=40,000)
そのとき, K(<=10,000)回クエリが与えられる. 各クエリでは2点が与えられるのでその2点の木の上での距離を答えよ.

解法


愚直にやると, 各クエリでO(N)かかるので全体でO(NK)で話にならない.
木の上での2点の距離は2点のLCAを仲介にして距離を求めることができる.
DFSで頂点を全てみるときの順番(戻るときも含める)に各点のDFSの深さを保存しておく.
そうすると2点の間の深さの最小を取ってくればそれがLCAになる.
明らかにRMQなのでSegment Treeで実装する.

また, 各点からその各祖先までの距離もそのDFSの順に求める. 頂点に入るときに辺の重みを足し, 戻るときに引けばBITで区間の距離の総和が出せる(ただし各祖先までの距離しか出せない).
よって2点の距離はそのLCAを介することで求められる.

DFSの順番について, 例えば以下の図のように探索するとき頂点の列は
1 2 5 2 6 2 1 3 1 4 1
となり, それに対応する深さの列は
0 1 2 1 2 0 1 0 1 0
となり, 各頂点について頂点に入るときの列の場所を覚えておけばその区間を見るだけでよいことがわかる

コード


(1986.cpp) download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <cstdio>
#include <vector>

using namespace std;
struct edge{int to, cst;};
vector<edge> G[40010];

const int inf = 1e9;
int N, M, K, bit_V, S;
int bit[80010], id[40010];
int vs[80010], depth[80010], st[(1<<18)+1];

void add(int i, int x){
    while(i <= bit_V){
        bit[i] += x;
        i += i&-i;
    }
}
int sum(int i){
    int ret = 0;
    while(i > 0){
        ret += bit[i];
        i -= i&-i;
    }
    return ret;
}

void dfs(int v, int p, int d, int &k){
    id[v] = k;
    vs[k] = v;
    depth[k++] = d;
    for(int i = 0; i < G[v].size(); i++){
        if(G[v][i].to != p){
            add(k,G[v][i].cst);
            dfs(G[v][i].to, v, d+1, k);
            vs[k] = v;
            depth[k++] = d;
            add(k,-G[v][i].cst);
        }
    }
}

inline int min_id(int a, int b){
    if(depth[a] < depth[b]) return a;
    return b;
}

int query(int a, int b, int k, int l, int r){
    if(r <= a or b <= l) return bit_V+1;
    if(a <= l and r <= b) return st[k];
    int m = (l + r) / 2;
    int vl = query(a, b, k*2+1, l, m);
    int vr = query(a, b, k*2+2, m, r);
    return min_id(vl, vr);
}

void rmq_init(){
    S = 1;
    while(S < bit_V) S <<= 1;
    for(int i = 0; i < 2*S-1; i++) st[i] = bit_V+1;
    depth[bit_V+1] = inf;
    for(int i = 0; i <= bit_V; i++){
        int k = i + S - 1;
        st[k] = i;
        while(k > 0){
            k = (k-1) / 2;
            st[k] = min_id(st[k*2+1], st[k*2+2]);
        }
        /*
        int k = id[i] + S - 1;
        st[k] = id[i];
        while(k > 0){
            k = (k-1) / 2;
            st[k] = min_id(st[k*2+1], st[k*2+2]);
        }
        */
    }
}

void init(){
    bit_V = (N-1)*2;
    int k = 0;
    dfs(0, -1, 0, k);
    rmq_init();
}

int lca(int u, int v){
    return vs[query(min(id[u],id[v]), max(id[u], id[v])+1, 0, 0, S)];
}

int main(){
    scanf("%d%d", &N, &M);
    for(int i = 0; i < M; i++){
        int a, b, c; char d;
        scanf("%d %d %d %c", &a, &b, &c, &d);
        a--; b--;
        G[a].push_back(edge{b,c});
        G[b].push_back(edge{a,c});
    }
    init();
    scanf("%d", &K);
    while(K--){
        int u, v; scanf("%d%d", &u, &v);
        u--; v--;
        int p = lca(u, v);
        printf("%d\n", sum(id[v])+sum(id[u])-2*sum(id[p]));
    }
    return 0;
}

Comments