cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

✔ test/src/convolution/relaxed_convolution_ntt/convolution_mod.test.cpp

Depends on

- library/convolution/relaxed_convolution_ntt.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#include <iostream>

#include <atcoder/modint>
#include <atcoder/convolution>

using mint = atcoder::modint998244353;

std::istream& operator>>(std::istream& in, mint& a) {
    long long e; in >> e; a = e;
    return in;
}

// Writes the value reported by a.val() to `out` and returns the same stream.
std::ostream& operator<<(std::ostream& out, const mint& a) {
    return out << a.val();
}

#include "library/convolution/relaxed_convolution_ntt.hpp"

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::size_t n, m;
    std::cin >> n >> m;

    std::vector<mint> a(n), b(m);
    for (auto& e : a) std::cin >> e;
    for (auto& e : b) std::cin >> e;

    suisen::RelaxedConvolutionNTT<mint> conv;

    for (std::size_t i = 0; i < n + m - 1; ++i) {
        conv.append(i < a.size() ? a[i] : 0, i < b.size() ? b[i] : 0);
    }
    auto c = conv.get();
    for (std::size_t i = 0; i < n + m - 1; ++i) {
        std::cout << c[i] << " \n"[i == n + m - 2];
    }
    return 0;
}
#line 1 "test/src/convolution/relaxed_convolution_ntt/convolution_mod.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#include <iostream>

#include <atcoder/modint>
#include <atcoder/convolution>

using mint = atcoder::modint998244353;

std::istream& operator>>(std::istream& in, mint& a) {
    long long e; in >> e; a = e;
    return in;
}

// Writes the value reported by a.val() to `out` and returns the same stream.
std::ostream& operator<<(std::ostream& out, const mint& a) {
    return out << a.val();
}

#line 1 "library/convolution/relaxed_convolution_ntt.hpp"



#line 5 "library/convolution/relaxed_convolution_ntt.hpp"

namespace suisen {
    // reference: https://qiita.com/Kiri8128/items/1738d5403764a0e26b4c
    /**
     * @brief Online (relaxed) convolution h = f * g accelerated with NTT.
     *
     * Coefficient pairs (f_i, g_i) are supplied one at a time through append(). The i-th call
     * (0-indexed) returns h_i = sum_{j=0}^{i} f_j * g_{i-j}, using only f_0..f_i and g_0..g_i.
     *
     * Scheme: whenever the number of appended terms is a multiple of 64, a block product is
     * computed with NTT and its contributions to later coefficients are pre-added into _h.
     * Other steps add the remaining short products naively.
     *
     * @tparam mint modular integer type accepted by atcoder::internal::butterfly
     *              (NOTE(review): presumably an atcoder static_modint with an NTT-friendly
     *              modulus such as 998244353 — confirm for other instantiations).
     */
    template <typename mint>
    struct RelaxedConvolutionNTT {
        RelaxedConvolutionNTT(): _n(0), _f{}, _g{}, _h{} {}

        /**
         * @brief Appends f_n = fi and g_n = gi, where n is the number of previous appends.
         * @return h_n = sum_{j=0}^{n} f_j * g_{n-j}
         */
        mint append(const mint& fi, const mint& gi) {
            // Ranges shorter than 64 terms are multiplied naively.
            static constexpr int threshold_log = 6;
            static constexpr int threshold = 1 << threshold_log;
            static constexpr int threshold_mask = threshold - 1;

            ++_n;
            _f.push_back(fi);
            _g.push_back(gi);

            const int q = _n >> threshold_log, r = _n & threshold_mask;
            if (r == 0) {
                if (q == (-q & q)) {
                    // _n = 64 * 2^p: multiply the whole prefixes f[0, _n) and g[0, _n).
                    // Coefficients below _n - 1 were already returned (final), so only indices
                    // >= _n - 1 receive the product. The transforms are cached so that later
                    // non-power-of-two steps can reuse them.
                    std::vector<mint> f_fft = _f;
                    std::vector<mint> g_fft = _g;
                    f_fft.resize(2 * _n);
                    g_fft.resize(2 * _n);
                    atcoder::internal::butterfly(f_fft);
                    atcoder::internal::butterfly(g_fft);
                    std::vector<mint> h(2 * _n);
                    for (int i = 0; i < 2 * _n; ++i) {
                        h[i] = f_fft[i] * g_fft[i];
                    }
                    atcoder::internal::butterfly_inv(h);
                    ensure(2 * _n);
                    const mint z = mint(2 * _n).inv();  // butterfly_inv leaves a factor of the length
                    for (int i = _n - 1; i < 2 * _n; ++i) {
                        _h[i] += h[i] * z;
                    }
                    _f_fft.push_back(std::move(f_fft));
                    _g_fft.push_back(std::move(g_fft));
                } else {
                    // _n = 64 * q with q not a power of two; k = 64 * lowbit(q).
                    // Add f[0, 2k) * g[_n-k, _n) + f[_n-k, _n) * g[0, 2k) for output indices
                    // [_n-1, _n-1+k). The cached transform of f[0, 2k) / g[0, 2k) (length 4k,
                    // stored at step 2^(log_q+1), which precedes q) is read only in its first
                    // 2k entries; this relies on those matching the length-2k transform of the
                    // same prefix (bit-reversed output order of atcoder's butterfly).
                    const int log_q = __builtin_ctz(q);
                    const int k = (-q & q) << threshold_log;

                    std::vector<mint> f_fft(_f.end() - k, _f.end());
                    std::vector<mint> g_fft(_g.end() - k, _g.end());
                    f_fft.resize(2 * k);
                    g_fft.resize(2 * k);
                    atcoder::internal::butterfly(f_fft);
                    atcoder::internal::butterfly(g_fft);
                    std::vector<mint> h(2 * k);
                    for (int i = 0; i < 2 * k; ++i) {
                        h[i] = _f_fft[log_q + 1][i] * g_fft[i] + f_fft[i] * _g_fft[log_q + 1][i];
                    }
                    atcoder::internal::butterfly_inv(h);
                    const mint z = mint(2 * k).inv();
                    // The linear product has degree <= 3k - 2, so the cyclic length-2k result
                    // only aliases onto indices [0, k - 1); the window [k - 1, 2k - 1) is exact.
                    // _h is already large enough: _n + k <= 2 * (last power-of-two checkpoint).
                    for (int i = 0; i < k; ++i) {
                        _h[_n - 1 + i] += h[k - 1 + i] * z;
                    }
                }
            } else {
                // Naive part: pairs that involve one of the first r coefficients and the newest
                // region. Every other pair for index _n - 1 was pre-added by block products.
                // While _n < 64 (_n == r) the first loop alone is the complete sum.
                ensure(_n);
                for (int i = 0; i < r; ++i) {
                    _h[_n - 1] += _f[i] * _g[_n - 1 - i];
                }
                if (_n != r) {
                    for (int i = 0; i < r; ++i) {
                        _h[_n - 1] += _f[_n - i - 1] * _g[i];
                    }
                }
            }
            return _h[_n - 1];
        }

        /// Returns h_i. Only i < size() is final; larger indices may hold partial sums.
        const mint& operator[](int i) const {
            return _h[i];
        }
        /// Number of finalized coefficients, i.e. the number of append() calls so far.
        int size() const {
            return _n;
        }
        /// Returns the finalized coefficients h_0 .. h_{size()-1}. The internal buffer is
        /// larger and its tail contains pre-accumulated partial sums, so it is truncated here.
        std::vector<mint> get() const {
            return std::vector<mint>(_h.begin(), _h.begin() + _n);
        }

    private:
        int _n;                            // number of appended terms
        std::vector<mint> _f, _g, _h;      // inputs and (partially accumulated) output

        // _f_fft[p] / _g_fft[p]: length 2 * 64 * 2^p transforms of the prefixes f / g[0, 64 * 2^p).
        std::vector<std::vector<mint>> _f_fft, _g_fft;

        // Grows _h (zero-filled) to at least n entries; never shrinks it.
        void ensure(std::size_t n) {
            if (_h.size() < n) _h.resize(n);
        }
    };
} // namespace suisen



#line 21 "test/src/convolution/relaxed_convolution_ntt/convolution_mod.test.cpp"

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::size_t n, m;
    std::cin >> n >> m;

    std::vector<mint> a(n), b(m);
    for (auto& e : a) std::cin >> e;
    for (auto& e : b) std::cin >> e;

    suisen::RelaxedConvolutionNTT<mint> conv;

    for (std::size_t i = 0; i < n + m - 1; ++i) {
        conv.append(i < a.size() ? a[i] : 0, i < b.size() ? b[i] : 0);
    }
    auto c = conv.get();
    for (std::size_t i = 0; i < n + m - 1; ++i) {
        std::cout << c[i] << " \n"[i == n + m - 2];
    }
    return 0;
}
Back to top page