Algoogle

Algorithm for Programming Contest

JOI 春合宿 2010 Highway

Category: JOI Tag: binary-indexed-tree, lowest-common-ancestor

Highway

問題概要


木が与えられる.
ノードの番号が小さい方から大きい方へいく時を上りといい, 逆の時を下りという.
以下のクエリを処理しろ.

  • ある辺の上りと下りにかかる時間の変更
  • ある2頂点x, yのxからyにいくときにかかる時間の総和の出力

解法


根付き木として考える.
頂点uから頂点vまでの距離を求めたい.
上りと下りの区別がないとして考える.
上りと下りの区別がないとき, 2頂点間の距離はその最小共通祖先wまでの距離の和と考えるとよい.
こうするとオイラーツアーしてやれば区間に部分木が入るので, 辺を降りるときに足して昇るときに引いてやればBITで各頂点からその子孫までの距離をO(log n)で得ることができる.
辺の更新はBITで対応する位置に現在の値との差分を足してやれば良い.

上りと下りの区別がある場合
パスの向きを考えるとu->w->vと進むことになることが分かる.
もっと詳しく見ると
u->wが根に向かって進む方向
w->vが葉に向かって進む方向
この2種類しかないので上りと下りをこれに対応させて2本のBIT, つまり根に向かう方向の距離を求めるのと葉に向かう方向の距離を求めるBITをそれぞれ別に用意すれば良い.

コード


(highway.cpp) download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <bits/stdc++.h>
using namespace std;
typedef pair<int,int> pii;

template<class T> class bit
{
public:
        vector<T> dat;
        int N;

        bit(){}
        bit(int N) : N(N){
                dat.assign(N,0);
        }
        // sum [0,i)
        T sum(int i){
                T ret = 0;
                for(--i; i>=0; i=(i&(i+1))-1) ret += dat[i];
                return ret;
        }
        // sum [i,j)
        T sum(int i, int j){ return sum(j) - sum(i);}
        // add x to i
        void add(int i, T x){ for(; i < N; i|=i+1) dat[i] += x;}
};


int n, m;
vector<pii> G[100010], es;

int dep[100010], par[100010][32], in[100010], out[100010], dist[100010][2];
bit<int> up, down;

void dfs(int v, int u, int d, int &idx)
{
        if(v) {
                dep[v] = d;
                par[v][0] = u;
                int k = 0;
                while(par[par[v][k]][k] >= 0) {
                        par[v][k+1] = par[par[v][k]][k];
                        k++;
                }
                up.add(idx,1);
                down.add(idx,1);
        }
        in[v] = idx++;
        for(auto &e: G[v])
                if(e.first != u) dfs(e.first, v, d+1, idx);

        if(v) {
                up.add(idx,-1);
                down.add(idx,-1);
        }
        out[v] = idx++;
}

void build_tree()
{
        memset(par,-1,sizeof(par));
        up = bit<int>(2*n);
        down = bit<int>(2*n);
        int idx = 0;
        dfs(0,-1,0, idx);
}

int lca(int u, int v)
{
        if(dep[u] > dep[v]) swap(u,v);
        int dif = dep[v] - dep[u];
        for (int i = 0; i < 18; i++)
                if(dif&(1<<i)) v = par[v][i];
        if(u == v) return u;
        for (int i = 17; i >= 0; i--)
                if(par[u][i] != par[v][i]) {
                        u = par[u][i];
                        v = par[v][i];
                }
        return par[u][0];
}

void solve()
{
        build_tree();
        char c;
        int u, v, w, r, s, t;
        while(m--) {
                cin >> c;
                if(c == 'I') {
                        cin >> r >> s >> t;
                        r--;
                        u = es[r].first;
                        v = es[r].second;
                        if(dep[v] > dep[u]) {
                                swap(u,v);
                                swap(s,t);
                        }
                        up.add(in[u], -dist[r][0]+s);
                        up.add(out[u], dist[r][0]-s);
                        down.add(in[u], -dist[r][1]+t);
                        down.add(out[u], dist[r][1]-t);
                        dist[r][0] = s;
                        dist[r][1] = t;
                }
                else {
                        cin >> u >> v;
                        u--; v--;
                        w = lca(u,v);
                        cout << up.sum(in[w]+1,in[u]+1)+down.sum(in[w]+1,in[v]+1) << endl;
                }
        }
}

void input()
{
        cin >> n >> m;
        for (int i = 0; i < n-1; i++) {
                int p, q; cin >> p >> q;
                dist[i][0] = dist[i][1] = 1;
                p--; q--;
                G[p].push_back(pii(q,i));
                G[q].push_back(pii(p,i));
                es.push_back(pii(p,q));
        }
}

int main()
{
        cin.tie(0);
        cin.sync_with_stdio(0);
        input();
        solve();
        return 0;
}

Comments