Algoogle

Algorithm for Programming Contest

Codeforces 438E The Child and Binary Tree

Category: Codeforces Tag: fft, math

The Child and Binary Tree

問題概要


2分木の各ノードに重みを割り当てるとき, 合計の重さ1〜mのそれぞれになりうる木の数を求めよ.
ただし各重さはのいずれかで, 同じ重さは何度使っても良い.

解法


FFTライブラリのverifyを兼ねてeditorial見ながら解いた.
ある重さの木の場合の数は根の重さとその2つの子の重さの場合によって決定する.
よって木の重さがwになる場合の数f[w]は,

と表せる. またここで次数を重さとみた多項式を考えると

1つ目は答えを係数に持つ多項式, 2つ目は使える重さの多項式が考えられる.
ある重さの木の場合の数の条件からこの2つの多項式は以下の関係を持つことがわかる.

これを解くと

となる. もう片方の解はf[0]が1であることを考えるとありえないことがわかる( などを考えてみるとよい).

あとはこれを実装すればよい. 多項式の平方根は

とすると

となるので順に計算すればよい.
あとは多項式の積をFFT(modの制約から剰余環を使ったものがよい)を用いて計算すること.

コード


(438E.cpp) download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;
#define repi(i,a,b) for(int i = (a); i < (b); i++)
#define rep(i,a) repi(i,0,a)
const int MAX = 1<<20;

int extgcd(int a, int b, int &x, int &y) {
    int g = a; x = 1; y = 0;
    if (b != 0) g = extgcd(b, a % b, y, x), y -= (a / b) * x;
    return g;
}

int mod_inverse(int a, int m){
    int x, y;
    if(extgcd(a, m, x, y) != 1) return 0; // unsolvable
    return (m + x % m) % m;
}

int pow_mod(int x, int k, int m) {
    int ret = 1;
    for(x%=m; k>0; x=1LL*x*x%m,k>>=1) if(k&1) ret = 1LL*ret*x%m;
    return ret;
}

const int mod = 7*17*(1<<23)+1;
// mod = a*2^k+1 and prime
// N = 2^e
vector<int> fmt(vector<int> f, bool inv){
    int e, N = f.size();
    assert((N&(N-1))==0 and "f.size() must be power of 2");
    for(e=0;;e++) if(N == (1<<e)) break;
    rep(m,N){
        int m2 = 0;
        rep(i,e) if(m&(1<<i)) m2 |= (1<<(e-1-i));
        if(m < m2) swap(f[m], f[m2]);
    }
    for(int t=1; t<N; t*=2){
        int r = pow_mod(3,(mod-1)/(t*2),mod);
        if(inv) r = mod_inverse(r,mod);
        for(int i=0; i<N; i+=2*t){
            int power = 1;
            rep(j,t){
                int x = f[i+j], y = 1LL*f[i+t+j]*power%mod;
                f[i+j] = (x+y)%mod;
                f[i+t+j] = (x-y+mod)%mod;
                power = 1LL*power*r%mod;
            }
        }
    }
    if(inv) for(int i=0,ni=mod_inverse(N,mod);i<N;i++) f[i] = 1LL*f[i]*ni%mod;
    return f;
}

vector<int> poly_mul(vector<int> f, vector<int> g){
    int N = max(f.size(),g.size())*2;
    f.resize(N); g.resize(N);
    f = fmt(f,0); g = fmt(g,0);
    rep(i,N) f[i] = 1LL*f[i]*g[i]%mod;
    f = fmt(f,1);
    return f;
}

vector<int> poly_inv(vector<int> f){
    int N = f.size();
    vector<int> r(1,mod_inverse(f[0],mod));
    for(int k = 2; k <= N; k <<= 1){
        vector<int> nr = poly_mul(poly_mul(r,r), vector<int>(f.begin(),f.begin()+k));
        nr.resize(k);
        rep(i,k/2) {
            nr[i] = (2*r[i]-nr[i]+mod)%mod;
            nr[i+k/2] = (mod-nr[i+k/2])%mod;
        }
        r = nr;
    }
    return r;
}

const int inv2 = (mod+1)/2;
vector<int> poly_sqrt(vector<int> f){
    int N = f.size();
    vector<int> s(1,1);
    for(int k = 2; k <= N; k <<= 1){
        s.resize(k);
        vector<int> ns = poly_mul(poly_inv(s), vector<int>(f.begin(),f.begin()+k));
        ns.resize(k);
        rep(i,k) s[i] = 1LL*(s[i]+ns[i])*inv2%mod;
    }
    return s;
}

int N, M;
vector<int> c, f;

void solve(){
    c[0] = 1;
    c = poly_sqrt(c);
    c[0] = (c[0]+1)%mod;
    f = poly_inv(c);
    rep(i,M) cout << f[i+1]*2%mod << endl;;
}

void input(){
    cin >> N >> M;
    int n = 1;
    while(n <= M) n <<= 1;
    c.resize(n,0);
    rep(i,N){
        int a; cin >> a;
        if(a > M) continue;
        c[a] = mod-4;
    }
    N = n;
}

int main(){
    input();
    solve();
    return 0;
}

Comments