Algoogle

Algorithm for Programming Contest

AOJ 2270 The L-th Number

Category: AOJ Tag: lowest-common-ancestor, wavelet-matrix

The L-th Number

問題概要


頂点に値が割り当てられてる木について以下のクエリQ個を捌け.

  • 2頂点v,w間のパス上の頂点のうち,l番目に小さいものを出力

解法


全体の方針としては,wavelet行列で列にした木上の2区間(パスをLCAで2つに分割したもの)内のL番目の数字を求める.
全体の計算量は最大値をMとして

適当な頂点を根にしてオイラーツアーして木の頂点の値を列にする.またこのときLCAの準備もしておく.
このとき入る時の値の列pと出る時の値の列qを作る.
p,qはpに0でない値が入っている位置はqでは0で,qに0でない値が入っている位置はpでは0になるようにする.
例えばオイラーツアーしてできる頂点番号の列が

1
1 2 2 3 4 4 3 1

だとして,それぞれ値が頂点番号と同じだとしたら

1
2
p : 1 2 0 3 4 0 0 0
q : 0 0 2 0 0 4 3 1

となる.
このpとqについてのwavelet行列を生成する.
このp,qでパス上の頂点の値の出現回数がわかるようになる.同じ区間で見て,pで増えて,qで減るから.

とする.
オイラーツアーしてできた列は頂点b,とその祖先a間のパスを,と表現できる(途中のいらない部分木は頂点に入るのと出るので打ち消されてる).
ここでとは列上で頂点aに入る位置.
あとはこうしてできた2区間上でL番目に小さい値を求めればよい.
あとは1つの区間でK番目に大きいものを求めるものを応用させればよい(詳しくはコードを参照してほしい).
ざっくり言うと,最上位bitからみて区間内で有効な値のうち1が立ってる数がLより大きいかどうかでwavelet行列の行上の区間を選んでいく.

コード


(2270.cpp) download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include <bits/stdc++.h>
using namespace std;
#define repi(i,a,b) for(int i = (int)(a); i < (int)(b); i++)
#define rep(i,n) repi(i,0,n)

template<int N> class FID {
    static const int bucket = 512, block = 16;
    static char popcount[];
    int n, B[N/bucket+1];
    unsigned short bs[N/block+1], b[N/block+1];

public:
    FID(){}
    FID(int n, bool s[]) : n(n) {
        if(!popcount[1]) for (int i = 0; i < (1<<block); i++) popcount[i] = __builtin_popcount(i);

        bs[0] = B[0] = b[0] = 0;
        for (int i = 0; i < n; i++) {
            if(i%block == 0) {
                bs[i/block+1] = 0;
                if(i%bucket == 0) {
                    B[i/bucket+1] = B[i/bucket];
                    b[i/block+1] = b[i/block] = 0;
                }
                else b[i/block+1] = b[i/block];
            }
            bs[i/block]   |= short(s[i])<<(i%block);
            b[i/block+1]  += s[i];
            B[i/bucket+1] += s[i];
        }
        if(n%bucket == 0) b[n/block] = 0;
    }

    int count(bool val, int r) { return val? B[r/bucket]+b[r/block]+popcount[bs[r/block]&((1<<(r%block))-1)]: r-count(1,r); }
    int count(bool val, int l, int r) { return count(val,r)-count(val,l); }
    bool operator[](int i) { return bs[i/block]>>(i%block)&1; }
};
template<int N> char FID<N>::popcount[1<<FID<N>::block];

template<class T, int N, int D> class wavelet {
public:
    int n, zs[D];
    FID<N> dat[D];

    wavelet(){}
    wavelet(int n, T seq[]) : n(n) {
        T f[N], l[N], r[N];
        bool b[N];
        memcpy(f, seq, sizeof(T)*n);
        for (int d = 0; d < D; d++) {
            int lh = 0, rh = 0;
            for (int i = 0; i < n; i++) {
                bool k = (f[i]>>(D-d-1))&1;
                if(k) r[rh++] = f[i];
                else l[lh++] = f[i];
                b[i] = k;
            }
            dat[d] = FID<N>(n,b);
            zs[d] = lh;
            swap(l,f);
            memcpy(f+lh, r, rh*sizeof(T));
        }
    }
};

const int N = 100100, D = 30;

wavelet<int,2*N,D> a, b;
int n, m, val[N], in[N], dep[N], par[N][20], idx, p[2*N], q[2*N], sz[N];
vector<int> G[N];

void euler_tour(int v)
{
    for (int k = 0; ~par[v][k] and ~par[par[v][k]][k]; k++)
        par[v][k+1] = par[par[v][k]][k];
    p[idx] = val[v];
    in[v] = idx++;
    for(int &w: G[v])
        if(w != par[v][0]) {
            par[w][0] = v;
            dep[w] = dep[v]+1;
            euler_tour(w);
        }
    q[idx++] = val[v];
}

void build()
{
    memset(par,-1,sizeof(par));
    euler_tour(n/2);
    a = wavelet<int,2*N,D>(2*n, p);
    b = wavelet<int,2*N,D>(2*n, q);
}

int lca(int u, int v)
{
    if(dep[u] > dep[v]) swap(u,v);
    int dif = dep[v] - dep[u];
    for (int i = 0; i < 18; i++)
        if(dif&(1<<i)) v = par[v][i];
    if(u == v) return u;
    for (int i = 17; i >= 0; i--)
        if(par[u][i] != par[v][i]) {
            u = par[u][i];
            v = par[v][i];
        }
    return par[u][0];
}

int query(int v, int w, int k)
{
    int u = lca(v,w);
    k = dep[v]+dep[w]-2*dep[u]-k+1;
    int lva = in[u], lvb = in[u],
        rva = in[v]+1, rvb = in[v]+1,
        lwa = in[u]+1, lwb = in[u]+1,
        rwa = in[w]+1, rwb = in[w]+1,
        ret = 0;
    rep(d,D) {
        int lvac = a.dat[d].count(1,lva), lvbc = b.dat[d].count(1,lvb),
            rvac = a.dat[d].count(1,rva), rvbc = b.dat[d].count(1,rvb),
            lwac = a.dat[d].count(1,lwa), lwbc = b.dat[d].count(1,lwb),
            rwac = a.dat[d].count(1,rwa), rwbc = b.dat[d].count(1,rwb);
        int lc = lvac-lvbc+lwac-lwbc,
            rc = rvac-rvbc+rwac-rwbc;
        if(rc-lc > k) {
            lva = lvac+a.zs[d]; lvb = lvbc+b.zs[d];
            rva = rvac+a.zs[d]; rvb = rvbc+b.zs[d];
            lwa = lwac+a.zs[d]; lwb = lwbc+b.zs[d];
            rwa = rwac+a.zs[d]; rwb = rwbc+b.zs[d];
            ret |= 1ULL<<(D-d-1);
        }
        else {
            k -= rc-lc;
            lva -= lvac; lvb -= lvbc;
            rva -= rvac; rvb -= rvbc;
            lwa -= lwac; lwb -= lwbc;
            rwa -= rwac; rwb -= rwbc;
        }
    }
    return ret;
}

void solve()
{
    build();
    while(m--) {
        int v, w, l;
        scanf("%d%d%d", &v, &w, &l);
        printf("%d\n", query(v-1,w-1,l));
    }
}

void input()
{
    scanf("%d%d", &n, &m);
    rep(i,n) scanf("%d", val+i);
    rep(i,n-1) {
        int a, b; scanf("%d%d", &a, &b);
        a--; b--;
        G[a].push_back(b);
        G[b].push_back(a);
    }
}

int main()
{
    input();
    solve();
    return 0;
}

Comments