cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

✔ test/src/number/mod_sqrt/dummy.test.cpp

Depends on: library/number/util.hpp (Utilities), library/number/mod_sqrt.hpp

Code

#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A"

#include <iostream>
#include <random>

#include "library/number/util.hpp"
#include "library/number/mod_sqrt.hpp"

// Exhaustively checks composite_mod_sqrt for every modulus m <= 700 and every residue a:
// a returned root must square to a, and a std::nullopt answer must mean no root exists.
void test_small() {
    for (int m = 1; m <= 700; ++m) {
        // The factorization depends only on m; compute it once instead of once per residue.
        const auto factorized = suisen::factorize(m);
        for (int a = 0; a < m; ++a) {
            auto x = suisen::composite_mod_sqrt(a, factorized);
            if (x) {
                // atcoder::crt yields the least non-negative representative, so x0 < m <= 700
                // and x0 * x0 fits in int.
                int x0 = *x;
                assert(x0 * x0 % m == a);
            } else {
                // No root reported: confirm by brute force that none exists.
                for (int b = 0; b < m; ++b) {
                    assert(b * b % m != a);
                }
            }
        }
    }
}

// Randomized check: 100 moduli drawn uniformly from [1, 10^12], each probed with 10^4 random
// residues. Only returned roots are validated; a std::nullopt result is not cross-checked here.
void test_large() {
    std::mt19937 rng{ 0 };
    std::uniform_int_distribution<long long> dist_m(1, 1000000000000);

    for (int round = 0; round < 100; ++round) {
        const long long m = dist_m(rng);
        const auto factorized = suisen::factorize(m);
        std::uniform_int_distribution<long long> dist_a(0, m - 1);

        for (int trial = 0; trial < 10000; ++trial) {
            const long long a = dist_a(rng);
            const std::optional<long long> root = suisen::composite_mod_sqrt(a, factorized);
            if (!root) continue;
            // Square in 128 bits: the root can be as large as m - 1 (about 10^12).
            const __int128_t root128 = *root;
            assert(root128 * root128 % m == a);
        }
    }
}

void test() {
    test_small();
    test_large();
}

int main() {
    // Any failed assert inside the checks terminates the program before the output below.
    test();
    // Expected output of the dummy judge problem referenced by PROBLEM.
    std::cout << "Hello World" << std::endl;
}
#line 1 "test/src/number/mod_sqrt/dummy.test.cpp"
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A"

#include <iostream>
#include <random>

#line 1 "library/number/util.hpp"



#include <array>
#include <cassert>
#include <cmath>
#include <numeric>
#include <tuple>
#include <vector>

/**
 * @brief Utilities
*/

namespace suisen {
    /**
     * @brief Computes a^b by binary exponentiation using O(log b) multiplications in T.
     * @tparam T integer type
     * @param a base
     * @param b exponent; must be non-negative
     * @return a^b (the caller must ensure it fits in T)
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    T powi(T a, int b) {
        // A negative exponent never reaches 0 under `b >>= 1` (arithmetic shift keeps -1),
        // which would turn the loop below into an infinite one.
        assert(b >= 0);
        T res = 1;
        T pow_a = a;
        while (b > 0) {
            if (b & 1) res *= pow_a;
            b >>= 1;
            // Square only while bits remain: the final square is never used and could overflow T
            // even when a^b itself fits (e.g. powi(2, 30) for int).
            if (b > 0) pow_a *= pow_a;
        }
        return res;
    }

    /**
     * @brief Calculates the prime factorization of n in O(√n) time by trial division.
     * @details The primes 2..13 are tried first; after that only candidates coprime to
     *          2*3*5*7*11*13 = 30030 are tried (5760 residues per period of 30030).
     * @tparam T integer type
     * @param n integer to factorize (n > 0)
     * @return vector of { prime, exponent }. It is guaranteed that prime is ascending.
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    std::vector<std::pair<T, int>> factorize(T n) {
        static constexpr std::array primes{ 2, 3, 5, 7, 11, 13 };
        static constexpr int next_prime = 17;
        // Number of residues coprime to the product of the first k primes (k = primes.size()).
        static constexpr int siz = std::array{ 1, 2, 8, 48, 480, 5760, 92160 } [primes.size() - 1] ;
        static constexpr int period = [] {
            int res = 1;
            for (auto e : primes) res *= e;
            return res;
        }();
        // Offsets (relative to next_prime) of the odd numbers in [next_prime, next_prime + period)
        // that are coprime to every prime in `primes`.
        static constexpr struct S : public std::array<int, siz> {
            // The base is value-initialized explicitly: C++17 constexpr constructors must
            // initialize every subobject.
            constexpr S() : std::array<int, siz>{} {
                for (int i = next_prime, j = 0; i < period + next_prime; i += 2) {
                    bool ok = true;
                    for (int p : primes) ok &= i % p > 0;
                    if (ok) (*this)[j++] = i - next_prime;
                }
            }
        } s{};

        assert(n > 0);
        std::vector<std::pair<T, int>> res;
        // Divides every power of p out of n and records (p, exponent) when p | n.
        // p has type T so that candidates beyond INT_MAX (possible for 64-bit n) are not truncated.
        auto f = [&res, &n](T p) {
            if (n % p) return;
            int cnt = 0;
            do n /= p, ++cnt; while (n % p == 0);
            res.emplace_back(p, cnt);
        };
        for (int p : primes) f(p);
        // `b <= n / b` is the overflow-free form of `b * b <= n`.
        for (T b = next_prime; b <= n / b; b += period) {
            for (int offset : s) {
                const T p = b + offset;
                // Once p * p > n, the remaining n is 1 or a single prime (handled below); this also
                // makes the outer condition fail on its next check.
                if (p > n / p) break;
                f(p);
            }
        }
        if (n != 1) res.emplace_back(n, 1);
        return res;
    }

    /**
     * @brief Lists every divisor of n, given n's prime factorization, in O(number of divisors) time.
     * @tparam T integer type
     * @param factorized prime factorization of n as a vector of { prime, exponent }
     * @return all divisors of n, in generation order (NOT sorted)
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    std::vector<T> divisors(const std::vector<std::pair<T, int>>& factorized) {
        std::vector<T> divs{ 1 };
        for (const auto& [prime, exponent] : factorized) {
            // Extend each divisor found so far by prime^1, ..., prime^exponent.
            const int known = divs.size();
            for (int idx = 0; idx < known; ++idx) {
                T cur = divs[idx];
                for (int e = 1; e <= exponent; ++e) {
                    cur *= prime;
                    divs.push_back(cur);
                }
            }
        }
        return divs;
    }
    /**
     * @brief Lists every divisor of n in O(√n) time (dominated by the factorization of n).
     * @tparam T integer type
     * @param n integer whose divisors are wanted (factorize requires n > 0)
     * @return all divisors of n (NOT sorted)
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    std::vector<T> divisors(T n) {
        const auto factorized = factorize(n);
        return divisors(factorized);
    }
    /**
     * @brief Calculates the divisors of every i = 1, ..., n in O(n log n) time.
     * @param n upper bound (closed)
     * @return 2-dim vector a of length n+1, where a[i] is the ascending vector of divisors of i
     *         (a[0] is empty).
     */
    inline std::vector<std::vector<int>> divisors_table(int n) {
        std::vector<std::vector<int>> divs(n + 1);
        // i divides exactly its multiples i, 2i, 3i, ...; visiting only those costs
        // n/1 + n/2 + ... + n/n = O(n log n) pushes. Increasing i keeps each list ascending.
        for (int i = 1; i <= n; ++i) {
            for (int j = i; j <= n; j += i) divs[j].push_back(i);
        }
        return divs;
    }

    /**
     * @brief Euler's totient φ(n) from the prime factorization of n, in O(log n).
     * @tparam T integer type
     * @param factorized prime factorization of n as a vector of { prime, exponent }
     * @return φ(n), the product of φ(p^e) = (p - 1) * p^(e - 1) over the prime powers of n
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    T totient(const std::vector<std::pair<T, int>>& factorized) {
        T phi = 1;
        for (const auto& [prime, exponent] : factorized) {
            const T prime_power_phi = (prime - 1) * powi(prime, exponent - 1);
            phi *= prime_power_phi;
        }
        return phi;
    }
    /**
     * @brief Euler's totient φ(n) in O(√n) time (dominated by the factorization of n).
     * @tparam T integer type
     * @param n integer argument (factorize requires n > 0)
     * @return φ(n)
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    T totient(T n) {
        const auto factorized = factorize(n);
        return totient(factorized);
    }
    /**
     * @brief Calculates φ(i) for every i = 0, ..., n with a sieve in O(n log log n) time.
     * @param n upper bound (closed)
     * @return vector a of length n+1, where a[i] = φ(i) for i = 1, ..., n (and a[0] = 0)
     */
    inline std::vector<int> totient_table(int n) {
        std::vector<int> res(n + 1);
        // Start from i with the factor (1 - 1/2) for the prime 2 already applied to even i.
        for (int i = 0; i <= n; ++i) res[i] = (i & 1) == 0 ? i >> 1 : i;
        // Every odd prime p <= n must be applied, not only those with p * p <= n
        // (e.g. φ(7) needs p = 7). An odd p is prime iff no smaller prime touched res[p].
        for (int p = 3; p <= n; p += 2) {
            if (res[p] != p) continue;
            for (int q = p; q <= n; q += p) res[q] = res[q] / p * (p - 1);
        }
        return res;
    }

    /**
     * @brief Carmichael's function λ(n) from the prime factorization of n, in O(log n).
     * @tparam T integer type
     * @param factorized prime factorization of n as a vector of { prime, exponent }
     * @return λ(n), the lcm of λ(p^e) over the prime powers p^e of n
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    T carmichael(const std::vector<std::pair<T, int>>& factorized) {
        T lambda = 1;
        for (const auto& [prime, exponent] : factorized) {
            T term = (prime - 1) * powi(prime, exponent - 1);  // φ(p^e)
            // λ(2^e) = φ(2^e) / 2 when e >= 3; otherwise λ(p^e) = φ(p^e).
            if (prime == 2 && exponent >= 3) term >>= 1;
            lambda = std::lcm(lambda, term);
        }
        return lambda;
    }
    /**
     * @brief Carmichael's function λ(n) in O(√n) time (dominated by the factorization of n).
     * @tparam T integer type
     * @param n integer argument (factorize requires n > 0)
     * @return λ(n)
     */
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    T carmichael(T n) {
        const auto factorized = factorize(n);
        return carmichael(factorized);
    }
} // namespace suisen


#line 1 "library/number/mod_sqrt.hpp"



#include <optional>
#include <atcoder/math>

namespace suisen {
    namespace internal {
        /**
         * @brief Inverse of a modulo m, delegating to atcoder::inv_mod.
         * @note Presumably requires gcd(a, m) = 1 and m >= 1 (the ACL inv_mod preconditions) — verify in callers.
         *       Declared inline because this header defines it; a non-inline definition breaks
         *       linking when the header is included from more than one translation unit.
         */
        inline long long inv_mod64(long long a, long long m) {
            return atcoder::inv_mod(a, m);
        }
        /**
         * @brief Computes a^b mod m by binary exponentiation with 128-bit intermediate products.
         * @param a base (may be negative; it is reduced into [0, m) first)
         * @param b exponent, expected to be >= 0
         * @param m modulus, m >= 1
         * @return a^b mod m, in [0, m)
         */
        inline long long pow_mod64(long long a, long long b, long long m) {
            if ((a %= m) < 0) a += m;
            // `1 % m` instead of `1` so that the result is 0 (not 1) when m == 1 and b == 0.
            long long res = 1 % m, pow_a = a;
            for (; b; b >>= 1) {
                if (b & 1) {
                    res = __int128_t(res) * pow_a % m;
                }
                pow_a = __int128_t(pow_a) * pow_a % m;
            }
            return res;
        }
        /**
         * @brief Computes a * b mod m using a 128-bit intermediate product (no overflow for 64-bit operands).
         * @note Uses C++ `%` semantics: the result has the sign of a * b and lies in (-m, m).
         *       Callers that pass a negative operand must normalize the result themselves.
         *       Declared inline because this header defines it (avoids multiple-definition link errors).
         */
        inline long long mul_mod64(long long a, long long b, long long m) {
            return __int128_t(a) * b % m;
        }
    }

    /**
     * @brief Computes a square root of a modulo a prime p (Tonelli–Shanks).
     * @param a value whose square root is wanted (may be negative; reduced into [0, p) first)
     * @param p modulus; assumed to be prime (not checked)
     * @return some x in [0, p) with x^2 ≡ a (mod p), or std::nullopt if a is a quadratic non-residue
     */
    inline std::optional<long long> prime_mod_sqrt(long long a, const long long p) {
        using namespace internal;

        if ((a %= p) < 0) a += p;

        if (a == 0) return 0;
        if (p == 2) return a;

        // Euler's criterion: a is a quadratic residue iff a^((p-1)/2) ≡ 1 (mod p).
        if (pow_mod64(a, (p - 1) / 2, p) != 1) {
            return std::nullopt;
        }

        // Smallest quadratic non-residue.
        long long b = 1;
        while (pow_mod64(b, (p - 1) / 2, p) == 1) {
            ++b;
        }

        // p - 1 = q * 2^tlz with q odd. p - 1 is 64-bit, so the 64-bit ctz is required: the
        // 32-bit __builtin_ctz truncates it and is undefined when its low 32 bits are all zero.
        const int tlz = __builtin_ctzll(p - 1);
        const long long q = (p - 1) >> tlz;

        const long long ia = inv_mod64(a, p);

        // Invariant at step `shift`: e = a^{-1} x^2 has order dividing 2^{tlz - shift + 1}, and
        // b = g^{2^{shift - 2}} with g = (non-residue)^q. Multiplying x by b when needed halves that order.
        long long x = pow_mod64(a, (q + 1) / 2, p);
        b = pow_mod64(b, q, p);
        for (int shift = 2;; ++shift) {
            const long long x2 = mul_mod64(x, x, p);
            if (x2 == a) {
                return x;
            }
            const long long e = mul_mod64(ia, x2, p);
            // tlz can be as large as 62, so the exponent 2^(tlz - shift) needs a 64-bit shift.
            if (pow_mod64(e, 1LL << (tlz - shift), p) != 1) {
                x = mul_mod64(x, b, p);
            }
            b = mul_mod64(b, b, p);
        }
    }

    namespace internal {
        /**
         * @brief Computes a square root of a modulo p^q for a prime p.
         * @param a value whose square root is wanted (may be negative; reduced into [0, p^q) first)
         * @param p prime (assumed, not checked)
         * @param q exponent, q >= 0; p^q must fit in long long
         * @return some x in [0, p^q) with x^2 ≡ a (mod p^q), or std::nullopt if none exists
         */
        inline std::optional<long long> prime_power_mod_sqrt(long long a, long long p, int q) {
            std::vector<long long> pq(q + 1);  // pq[i] = p^i
            pq[0] = 1;
            for (int i = 1; i <= q; ++i) {
                pq[i] = pq[i - 1] * p;
            }
            // Reduce into [0, p^q). C++ `%` keeps the sign of a, and a negative remainder would make
            // the residue checks below (e.g. `a % 8 != 1` for p = 2) reject valid inputs.
            if ((a %= pq[q]) < 0) a += pq[q];
            if (a == 0) return 0;

            // Write a = p^b * a' with p not dividing a'. A square has even p-adic valuation, so b must
            // be even; then x = p^{b/2} * y where y^2 ≡ a' (mod p^{q-b}).
            int b = 0;
            for (; a % p == 0; a /= p) {
                ++b;
            }
            if (b % 2) {
                return std::nullopt;
            }
            const long long c = pq[b / 2];

            q -= b;

            if (p != 2) {
                // reference: http://aozoragakuen.sakura.ne.jp/suuron/node24.html
                // f(x) = x^2 - a, f'(x) = 2x
                // Lifting from f(x_i)=0 mod p^i to f(x_{i+1})=0 mod p^{i+1}
                auto ox = prime_mod_sqrt(a, p);
                if (not ox) {
                    return std::nullopt;
                }
                long long x = *ox;
                // f'(x_0) = 2 x_0 is invertible mod p: p is odd and x_0 != 0 because p does not divide a'.
                const long long inv_df_x0 = inv_mod64(2 * x, p);
                for (int i = 1; i < q; ++i) {
                    // Requirements:
                    //      x_{i+1} = x_i + p^i * y for some 0 <= y < p.
                    // Taylor expansion:
                    //      f(x_i + p^i y) = f(x_i) + y p^i f'(x_i) + p^{i+1} * (...)
                    // f(x_i) = 0 (mod p^i) and f'(x_i) = f'(x_0) != 0 (mod p), so
                    //      y = -(f(x_i)/p^i) * f'(x_0)^(-1) (mod p)
                    __int128_t f_x = __int128_t(x) * x - a;
                    long long y = mul_mod64(-(f_x / pq[i]) % p, inv_df_x0, p);
                    if (y < 0) y += p;  // mul_mod64 keeps the sign of its product
                    x += pq[i] * y;
                }
                return c * x;
            } else {
                // p = 2 and a' is odd with a' < 2^q. For q >= 3, a' ≡ 1 (mod 8) is exactly the
                // solvability condition. For q <= 2 it coincides with it: q = 1 forces a' = 1, and for
                // q = 2 only a' = 1 (not 3) is a square mod 4.
                if (a % 8 != 1) {
                    return std::nullopt;
                }
                // reference: https://twitter.com/maspy_stars/status/1613931151718244352?s=20&t=lAf7ztW2fb_IZa544lo2xw
                long long x = 1; // or 3
                for (int i = 3; i < q; ++i) {
                    // Requirements:
                    //      x_{i+1} = x_i + 2^{i-1} y for some 0 <= y < 2.
                    // x_i is an odd number, so
                    //      (x_i + 2^{i-1} y)^2 = x_i^2 + y 2^i (mod 2^{i+1}).
                    // Therefore,
                    //      y = (a - x_i^2)/2^i (mod 2).
                    __int128_t f_x = __int128_t(x) * x - a;
                    x |= ((f_x >> i) & 1) << (i - 1);
                }
                return c * x;
            }
        }
    }

    /**
     * @brief Computes a square root of a modulo m, given the prime factorization of m.
     * @details Solves x^2 ≡ a modulo each prime power p^q of m independently, then combines the
     *          per-prime-power roots with the Chinese remainder theorem (atcoder::crt).
     * @tparam PrimePowers iterable of { prime, exponent } pairs (e.g. the result of factorize)
     * @param a value whose square root is wanted
     * @param factorized prime factorization of m
     * @return some x with x^2 ≡ a (mod m), or std::nullopt if a has no square root modulo some p^q
     */
    template <typename PrimePowers>
    std::optional<long long> composite_mod_sqrt(long long a, const PrimePowers& factorized) {
        std::vector<long long> residues, moduli;
        for (const auto& [prime, exponent] : factorized) {
            const std::optional<long long> root = internal::prime_power_mod_sqrt(a, prime, exponent);
            if (!root) return std::nullopt;
            long long prime_power = 1;
            for (int i = 0; i < exponent; ++i) prime_power *= prime;
            residues.push_back(*root);
            moduli.push_back(prime_power);
        }
        return atcoder::crt(residues, moduli).first;
    }
} // namespace suisen



#line 8 "test/src/number/mod_sqrt/dummy.test.cpp"

// Exhaustively checks composite_mod_sqrt for every modulus m <= 700 and every residue a:
// a returned root must square to a, and a std::nullopt answer must mean no root exists.
void test_small() {
    for (int m = 1; m <= 700; ++m) {
        // The factorization depends only on m; compute it once instead of once per residue.
        const auto factorized = suisen::factorize(m);
        for (int a = 0; a < m; ++a) {
            auto x = suisen::composite_mod_sqrt(a, factorized);
            if (x) {
                // atcoder::crt yields the least non-negative representative, so x0 < m <= 700
                // and x0 * x0 fits in int.
                int x0 = *x;
                assert(x0 * x0 % m == a);
            } else {
                // No root reported: confirm by brute force that none exists.
                for (int b = 0; b < m; ++b) {
                    assert(b * b % m != a);
                }
            }
        }
    }
}

// Randomized check: 100 moduli drawn uniformly from [1, 10^12], each probed with 10^4 random
// residues. Only returned roots are validated; a std::nullopt result is not cross-checked here.
void test_large() {
    std::mt19937 rng{ 0 };
    std::uniform_int_distribution<long long> dist_m(1, 1000000000000);

    for (int round = 0; round < 100; ++round) {
        const long long m = dist_m(rng);
        const auto factorized = suisen::factorize(m);
        std::uniform_int_distribution<long long> dist_a(0, m - 1);

        for (int trial = 0; trial < 10000; ++trial) {
            const long long a = dist_a(rng);
            const std::optional<long long> root = suisen::composite_mod_sqrt(a, factorized);
            if (!root) continue;
            // Square in 128 bits: the root can be as large as m - 1 (about 10^12).
            const __int128_t root128 = *root;
            assert(root128 * root128 % m == a);
        }
    }
}

void test() {
    test_small();
    test_large();
}

int main() {
    // Any failed assert inside the checks terminates the program before the output below.
    test();
    // Expected output of the dummy judge problem referenced by PROBLEM.
    std::cout << "Hello World" << std::endl;
}
Back to top page