cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

:heavy_check_mark: 合成
(library/polynomial/compose.hpp)

合成

多項式 $\displaystyle f(x) = \sum _ {i = 0} ^ {N - 1} a _ i x ^ i,\ g(x) = \sum _ {i = 0} ^ {M - 1} b _ i x ^ i$ に対して合成 $(f\circ g)(x) \pmod{x ^ K}$ を $O(NK + \sqrt{N}K\log K)$ 時間で計算するライブラリ。

アルゴリズム

以下は全て $\mathrm{mod}\ x ^ K$ で計算する。

$\displaystyle f\circ g = \sum _ {i = 0} ^ {N - 1} a _ i g ^ i$ を計算したい。$B \coloneqq \lceil \sqrt{N} \rceil$ とすれば $\displaystyle f\circ g = \sum _ {i = 0} ^ {B - 1} (g ^ B) ^ i \sum _ {j = 0} ^ {B - 1} a _ {iB + j} g ^ j$ と表せる。ただし $i \geq N$ に対して $a _ i = 0$ とする。

$g ^ 0, g ^ 1, \ldots, g ^ {B - 1}$ および $G\coloneqq g ^ B$ に対して $G ^ 0, G ^ 1, \ldots, G ^ {B - 1}$ を全て ($\mathrm{mod}\ x ^ K$ で) 前計算する。この前計算は $O(\sqrt N K \log K)$ 時間で可能である。

この前計算の結果を用いれば、各 $i=0,1,\ldots,B-1$ に対して、$\displaystyle \sum _ {j = 0} ^ {B - 1}a _ {iB + j} g ^ j \bmod x ^ K$ を $O(\sqrt{N} K)$ 時間で計算できる。従って、$\displaystyle \left(G ^ i \sum _ {j = 0} ^ {B - 1}a _ {iB + j} g ^ j\right) \bmod x ^ K$ の計算は $O(\sqrt{N} K + K \log K)$ 時間で可能である。

以上をまとめると、全体の計算量は $O(NK + \sqrt{N}K\log K)$ 時間となる。

高速化 (1)

$g ^ i,G ^ i$ の前計算において、$\mathbf{FFT}(g)$ や $\mathbf{FFT}(G)$ を $1$ 回しか計算しないことで定数倍高速化を図ることができる。

高速化 (2)

$\lbrack x ^ 0 \rbrack g = 0$ が成り立つ場合、$\displaystyle f\circ g = \sum _ {i = 0} ^ {B - 1} x ^ {Bi} ((g / x) ^ B) ^ i \sum _ {j = 0} ^ {B - 1} a _ {iB + j} x ^ j (g / x) ^ j$ において $\displaystyle \sum _ {j = 0} ^ {B - 1} a _ {iB + j} x ^ j (g / x) ^ j$ は $\mathrm{mod}\ x ^ {N - iB}$ で求めれば十分である。従って、$NK$ に付く係数をおよそ $1/2$ 倍に削減できる。またこの場合は $f\leftarrow f \bmod x ^ K$ としてもよいので、計算量は本質的に改善されて $O(\min(N,K)K + \sqrt{\min(N,K)}K\log K)$ 時間となる。

$\lbrack x ^ 0 \rbrack g \neq 0$ の場合は $f$ を taylor shift して $f \leftarrow f(x + \lbrack x ^ 0 \rbrack g)$ とおき直すことで $\lbrack x ^ 0 \rbrack g = 0$ の場合に帰着できる。帰着に掛かる計算量は $O(N\log N)$ であり、計算量は $O(\min(N,K)K + \sqrt{\min(N,K)}K\log K + N \log N)$ である。なお、$K \leq \log N$ など $K$ が極端に小さい場合は、$N\log N$ の項が支配的になるため帰着を行わない方が高速となる可能性がある。

Verified with

Code

#ifndef SUISEN_POLY_COMPOSE
#define SUISEN_POLY_COMPOSE

#include <cmath>
#include <vector>
#include <atcoder/convolution>

namespace suisen {
    template <typename mint>
    std::vector<mint> compose(const std::vector<mint>& f, std::vector<mint> g, const int n) {
        std::vector<mint> res(n);
        if (n == 0) return res;
        if (f.empty()) return res;

        if (std::find_if(g.begin(), g.end(), [](mint x) { return x != 0; }) == g.end()) return res[0] = f[0], res;

        // taylor shift f(x + [x^0]g)
        const std::vector<mint> fa = [&]{
            const mint a = std::exchange(g[0], 0);
            const int siz_f = f.size();
            
            std::vector<mint> fac(siz_f), fac_inv(siz_f);
            fac[0] = 1;
            for (int i = 1; i <= siz_f - 1; ++i) fac[i] = fac[i - 1] * i;
            fac_inv[siz_f - 1] = fac[siz_f - 1].inv();
            for (int i = siz_f - 1; i >= 1; --i) fac_inv[i - 1] = fac_inv[i] * i;

            std::vector<mint> ec(siz_f), fa(siz_f);
            mint p = 1;
            for (int i = 0; i < siz_f; ++i, p *= a) {
                ec[i] = p * fac_inv[i];
                fa[siz_f - 1 - i] = (i < int(f.size()) ? f[i] : 0) * fac[i];
            }
            fa = atcoder::convolution(fa, ec), fa.resize(siz_f);
            std::reverse(fa.begin(), fa.end());
            for (int i = 0; i < siz_f; ++i) {
                fa[i] *= fac_inv[i];
            }
            if (siz_f > n) fa.resize(n);
            return fa;
        }();

        const int sqn = ::sqrt(f.size()) + 1;

        const int z = [n]{
            int z = 1;
            while (z < 2 * n - 1) z <<= 1;
            return z;
        }();
        const mint iz = mint(z).inv();

        g.erase(g.begin());
        g.resize(z);
        atcoder::internal::butterfly(g);

        auto mult_g = [&](std::vector<mint> a) {
            a.resize(z);
            atcoder::internal::butterfly(a);
            for (int j = 0; j < z; ++j) a[j] *= g[j] * iz;
            atcoder::internal::butterfly_inv(a);
            a.resize(n);
            return a;
        };

        std::vector<std::vector<mint>> pow_g(sqn, std::vector<mint>(n));
        pow_g[0][0] = 1;
        for (int i = 1; i < sqn; ++i) {
            pow_g[i] = mult_g(pow_g[i - 1]);
        }

        std::vector<mint> gl = mult_g(pow_g[sqn - 1]);
        gl.resize(z);
        atcoder::internal::butterfly(gl);

        std::vector<mint> pow_gl(z);
        pow_gl[0] = 1;

        for (int i = 0; i < sqn; ++i) {
            const int off_i = i * sqn;
            const int siz_i = n - off_i;
            if (siz_i < 0) break;
            std::vector<mint> fg(siz_i);
            for (int j = 0; j < sqn; ++j) {
                const int ij = i * sqn + j;
                if (ij >= int(fa.size())) break;

                const mint c = fa[ij];
                const std::vector<mint>& gj = pow_g[j];
                for (int k = 0; k < siz_i - j; ++k) {
                    fg[j + k] += c * gj[k];
                }
            }
            fg.resize(z);
            atcoder::internal::butterfly(pow_gl);
            atcoder::internal::butterfly(fg);
            for (int k = 0; k < z; ++k) {
                fg[k] *= pow_gl[k] * iz;
                pow_gl[k] *= gl[k] * iz;
            }
            atcoder::internal::butterfly_inv(pow_gl);
            atcoder::internal::butterfly_inv(fg);
            for (int k = 0; k < siz_i; ++k) {
                res[off_i + k] += fg[k];
            }
            std::fill(pow_gl.begin() + n, pow_gl.end(), 0);
        }
        return res;
    }
} // namespace suisen


#endif // SUISEN_POLY_COMPOSE
#line 1 "library/polynomial/compose.hpp"



#include <cmath>
#include <vector>
#include <atcoder/convolution>

namespace suisen {
    template <typename mint>
    std::vector<mint> compose(const std::vector<mint>& f, std::vector<mint> g, const int n) {
        std::vector<mint> res(n);
        if (n == 0) return res;
        if (f.empty()) return res;

        if (std::find_if(g.begin(), g.end(), [](mint x) { return x != 0; }) == g.end()) return res[0] = f[0], res;

        // taylor shift f(x + [x^0]g)
        const std::vector<mint> fa = [&]{
            const mint a = std::exchange(g[0], 0);
            const int siz_f = f.size();
            
            std::vector<mint> fac(siz_f), fac_inv(siz_f);
            fac[0] = 1;
            for (int i = 1; i <= siz_f - 1; ++i) fac[i] = fac[i - 1] * i;
            fac_inv[siz_f - 1] = fac[siz_f - 1].inv();
            for (int i = siz_f - 1; i >= 1; --i) fac_inv[i - 1] = fac_inv[i] * i;

            std::vector<mint> ec(siz_f), fa(siz_f);
            mint p = 1;
            for (int i = 0; i < siz_f; ++i, p *= a) {
                ec[i] = p * fac_inv[i];
                fa[siz_f - 1 - i] = (i < int(f.size()) ? f[i] : 0) * fac[i];
            }
            fa = atcoder::convolution(fa, ec), fa.resize(siz_f);
            std::reverse(fa.begin(), fa.end());
            for (int i = 0; i < siz_f; ++i) {
                fa[i] *= fac_inv[i];
            }
            if (siz_f > n) fa.resize(n);
            return fa;
        }();

        const int sqn = ::sqrt(f.size()) + 1;

        const int z = [n]{
            int z = 1;
            while (z < 2 * n - 1) z <<= 1;
            return z;
        }();
        const mint iz = mint(z).inv();

        g.erase(g.begin());
        g.resize(z);
        atcoder::internal::butterfly(g);

        auto mult_g = [&](std::vector<mint> a) {
            a.resize(z);
            atcoder::internal::butterfly(a);
            for (int j = 0; j < z; ++j) a[j] *= g[j] * iz;
            atcoder::internal::butterfly_inv(a);
            a.resize(n);
            return a;
        };

        std::vector<std::vector<mint>> pow_g(sqn, std::vector<mint>(n));
        pow_g[0][0] = 1;
        for (int i = 1; i < sqn; ++i) {
            pow_g[i] = mult_g(pow_g[i - 1]);
        }

        std::vector<mint> gl = mult_g(pow_g[sqn - 1]);
        gl.resize(z);
        atcoder::internal::butterfly(gl);

        std::vector<mint> pow_gl(z);
        pow_gl[0] = 1;

        for (int i = 0; i < sqn; ++i) {
            const int off_i = i * sqn;
            const int siz_i = n - off_i;
            if (siz_i < 0) break;
            std::vector<mint> fg(siz_i);
            for (int j = 0; j < sqn; ++j) {
                const int ij = i * sqn + j;
                if (ij >= int(fa.size())) break;

                const mint c = fa[ij];
                const std::vector<mint>& gj = pow_g[j];
                for (int k = 0; k < siz_i - j; ++k) {
                    fg[j + k] += c * gj[k];
                }
            }
            fg.resize(z);
            atcoder::internal::butterfly(pow_gl);
            atcoder::internal::butterfly(fg);
            for (int k = 0; k < z; ++k) {
                fg[k] *= pow_gl[k] * iz;
                pow_gl[k] *= gl[k] * iz;
            }
            atcoder::internal::butterfly_inv(pow_gl);
            atcoder::internal::butterfly_inv(fg);
            for (int k = 0; k < siz_i; ++k) {
                res[off_i + k] += fg[k];
            }
            std::fill(pow_gl.begin() + n, pow_gl.end(), 0);
        }
        return res;
    }
} // namespace suisen
Back to top page