cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

✔ Deterministic Miller–Rabin
(library/number/deterministic_miller_rabin.hpp)

Deterministic Miller–Rabin primality test for integers n < 2^64, choosing a fixed set of bases according to the size of n.

Depends on

Required by

Verified with

Code

#ifndef SUISEN_DETERMINISTIC_MILLER_RABIN
#define SUISEN_DETERMINISTIC_MILLER_RABIN

#include <array>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <tuple>
#include <type_traits>

#include "library/number/montogomery.hpp"

namespace suisen::miller_rabin {
    namespace internal {
        // Base sets for the strong probable-prime (Miller-Rabin) test.
        // is_prime uses BASE_k for every n < THRESHOLD_k, and BASE_7 for every n >= THRESHOLD_6.
        // Passing all bases of the selected set is taken as proof of primality in that range.
        // Bases larger than n are reduced modulo n by miller_rabin.
        // NOTE(review): these appear to be the published deterministic SPRP base records;
        // confirm the source before editing any constant.
        constexpr uint64_t THRESHOLD_1 = 341531ULL;
        constexpr uint64_t BASE_1[]{ 9345883071009581737ULL };

        constexpr uint64_t THRESHOLD_2 = 1050535501ULL;
        constexpr uint64_t BASE_2[]{ 336781006125ULL, 9639812373923155ULL };

        constexpr uint64_t THRESHOLD_3 = 350269456337ULL;
        constexpr uint64_t BASE_3[]{ 4230279247111683200ULL, 14694767155120705706ULL, 16641139526367750375ULL };

        constexpr uint64_t THRESHOLD_4 = 55245642489451ULL;
        constexpr uint64_t BASE_4[]{ 2ULL, 141889084524735ULL, 1199124725622454117ULL, 11096072698276303650ULL };

        constexpr uint64_t THRESHOLD_5 = 7999252175582851ULL;
        constexpr uint64_t BASE_5[]{ 2ULL, 4130806001517ULL, 149795463772692060ULL, 186635894390467037ULL, 3967304179347715805ULL };

        constexpr uint64_t THRESHOLD_6 = 585226005592931977ULL;
        constexpr uint64_t BASE_6[]{ 2ULL, 123635709730000ULL, 9233062284813009ULL, 43835965440333360ULL, 761179012939631437ULL, 1263739024124850375ULL };

        // Used for all remaining n (up to 2^64 - 1); no threshold is checked.
        constexpr uint64_t BASE_7[]{ 2U, 325U, 9375U, 28178U, 450775U, 9780504U, 1795265022U };

        // Strong probable-prime test of n against the first SIZE entries of BASE.
        // Handles n <= 1, the primes 2, 3, 5, 7 and their multiples by trial division.
        // Any remaining n < 121 is reported prime.
        // @returns false if some base witnesses that n is composite, true otherwise.
        template <auto BASE, std::size_t SIZE>
        constexpr bool miller_rabin(uint64_t n) {
            if (n == 2 or n == 3 or n == 5 or n == 7) return true;
            if (n <= 1 or n % 2 == 0 or n % 3 == 0 or n % 5 == 0 or n % 7 == 0) return false;
            // n < 121 = 11^2 with no prime factor <= 7 must be prime.
            if (n < 121) return true;

            // Write n - 1 = d * 2^s with d odd.
            const uint32_t s = __builtin_ctzll(n - 1); // >= 1
            const uint64_t d = (n - 1) >> s;

            const Montgomery64 mg{ n };

            // 1 and -1 (mod n) in Montgomery form, so they compare directly against Y.
            const uint64_t one = mg.make(1), minus_one = mg.make(n - 1);

            for (std::size_t i = 0; i < SIZE; ++i) {
                uint64_t a = BASE[i] % n;
                // A base that is a multiple of n gives no information; skip it.
                if (a == 0) continue;
                uint64_t Y = mg.pow(mg.make(a), d);
                if (Y == one) continue;
                for (uint32_t r = 0;; ++r, Y = mg.mul(Y, Y)) {
                    // Y = a^(d 2^r)
                    if (Y == minus_one) break;
                    // None of a^d, a^(2d), ..., a^(2^(s-1) d) equals -1: a witnesses compositeness.
                    if (r == s - 1) return false;
                }
            }
            return true;
        }
    }

    // Deterministic primality test.
    // @param n integer to test; must satisfy 0 <= n < 2^64 (checked with assert).
    // @returns true iff n is prime. A negative n returns false even when asserts are disabled.
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    constexpr bool is_prime(T n) {
        if constexpr (std::is_signed_v<T>) {
            assert(n >= 0);
            // Under NDEBUG the assert vanishes. Without this guard a negative n would be
            // converted to a large unsigned value below and tested in its place.
            if (n < 0) return false;
        }
        const std::make_unsigned_t<T> n_unsigned = n;
        assert(n_unsigned <= std::numeric_limits<uint64_t>::max()); // n < 2^64
        using namespace internal;
        // Use the smallest base set whose threshold exceeds n.
        if (n_unsigned < THRESHOLD_1) return miller_rabin<BASE_1, 1>(n_unsigned);
        if (n_unsigned < THRESHOLD_2) return miller_rabin<BASE_2, 2>(n_unsigned);
        if (n_unsigned < THRESHOLD_3) return miller_rabin<BASE_3, 3>(n_unsigned);
        if (n_unsigned < THRESHOLD_4) return miller_rabin<BASE_4, 4>(n_unsigned);
        if (n_unsigned < THRESHOLD_5) return miller_rabin<BASE_5, 5>(n_unsigned);
        if (n_unsigned < THRESHOLD_6) return miller_rabin<BASE_6, 6>(n_unsigned);
        return miller_rabin<BASE_7, 7>(n_unsigned);
    }
} // namespace suisen::miller_rabin

#endif // SUISEN_DETERMINISTIC_MILLER_RABIN
#line 1 "library/number/deterministic_miller_rabin.hpp"



#include <array>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>

#line 1 "library/number/montogomery.hpp"



#line 6 "library/number/montogomery.hpp"
#include <limits>

namespace suisen {
    namespace internal::montgomery {
        // Montgomery modular arithmetic modulo a fixed odd N, with R = 2^bits(Int) > N.
        // add / sub / mul / div / inv / pow take and return values in Montgomery form
        // (x * R mod N, in [0, N)). make() converts into that form and reduce() converts back.
        // Int is the unsigned word type; DInt is an unsigned type of twice its width
        // (see the Montgomery32 / Montgomery64 aliases below).
        template <typename Int, typename DInt>
        struct Montgomery {
        private:
            static constexpr uint32_t bits = std::numeric_limits<Int>::digits;
            static constexpr Int mask = ~Int(0);
            // R = 2**bits (2**32 or 2**64)

            // Invariants:
            // 1. N is an odd number
            // 2. N < R
            // 3. gcd(N, R) = 1
            // 4. R * R2 - N * N2 = 1
            // 5. 0 < R2 < N
            // 6. 0 < N2 < R
            Int N{}, N2{}, R2{};

            // RR = R * R (mod N)
            Int RR{};
        public:
            constexpr Montgomery() = default;
            explicit constexpr Montgomery(Int N) : N(N), N2(calcN2(N)), R2(calcR2(N, N2)), RR(calcRR(N)) {
                assert(N & 1);
            }

            // @returns t * R (mod N), the Montgomery form of t (t may be any Int value).
            constexpr Int make(Int t) const {
                return reduce(static_cast<DInt>(t) * RR);
            }
            // @returns T * R^(-1) (mod N), in [0, N). In particular reduce(make(t)) == t % N.
            // @pre 0 <= T < R * N
            constexpr Int reduce(DInt T) const {
                // m = T * N2 (mod R), 0 <= m < R.
                const Int m = modR(static_cast<DInt>(modR(T)) * N2);

                // T + m * N = T + T * N * N2 = T + T * (R * R2 - 1) = 0 (mod R), so
                // t = (T + m * N) / R is an integer and t = T * R^(-1) (mod N).
                // T < R * N and m * N < R * N give T + m * N < 2 * R * N. When N is close
                // to R this can reach 2^(2 * bits) and wrap in DInt, so detect the carry explicitly.
                const DInt mN = static_cast<DInt>(m) * N;
                const DInt sum = T + mN;
                const bool carry = sum < mN;
                const Int t = divR(sum);

                // The true quotient t + carry * R lies in [0, 2N); subtract N once if needed.
                // When carry is set, t - N wraps modulo R to exactly (t + R) - N.
                return (carry || t >= N) ? static_cast<Int>(t - N) : t;
            }

            // @returns A + B (mod N) for A, B in [0, N). A + B itself is never formed,
            // so this cannot overflow Int even when N > R / 2.
            constexpr Int add(Int A, Int B) const {
                const Int room = N - B; // A + B >= N  <=>  A >= N - B
                return A >= room ? A - room : A + B;
            }
            // @returns A - B (mod N) for A, B in [0, N).
            constexpr Int sub(Int A, Int B) const {
                return A >= B ? A - B : A + (N - B);
            }
            // @returns A * B * R^(-1) (mod N), the Montgomery product.
            constexpr Int mul(Int A, Int B) const {
                return reduce(static_cast<DInt>(A) * B);
            }
            // @returns A / B in Montgomery form. Requires gcd(B, N) == 1.
            constexpr Int div(Int A, Int B) const {
                return reduce(static_cast<DInt>(A) * inv(B));
            }
            // @returns the Montgomery form of a^(-1), where A = a * R (mod N).
            // @pre gcd(A, N) == 1 (checked with assert)
            constexpr Int inv(Int A) const {
                // Extended Euclid with invariants r0 = A * x0 and r1 = A * x1 (mod N),
                // keeping the coefficients reduced into [0, N).
                Int r0 = A, r1 = N;
                Int x0 = 1, x1 = 0;
                while (r1 != 0) {
                    const Int q = r0 / r1;
                    const Int r2 = r0 - q * r1;
                    // x2 = x0 - q * x1 (mod N), with the product taken in DInt.
                    const Int qx1 = static_cast<Int>(static_cast<DInt>(q % N) * x1 % N);
                    const Int x2 = x0 >= qx1 ? x0 - qx1 : x0 + (N - qx1);
                    r0 = r1, r1 = r2;
                    x0 = x1, x1 = x2;
                }
                assert(r0 == 1);
                // x0 = A^(-1) = a^(-1) * R^(-1) (mod N). Each make() multiplies by R,
                // giving a^(-1) * R.
                return make(make(x0));
            }

            // @returns A^b in Montgomery form.
            // @pre b >= 0 (a negative b never terminates the loop below)
            constexpr Int pow(Int A, long long b) const {
                Int P = make(1);
                for (; b; b >>= 1) {
                    if (b & 1) P = mul(P, A);
                    A = mul(A, A);
                }
                return P;
            }

        private:
            static constexpr Int divR(DInt t) { return t >> bits; }
            static constexpr Int modR(DInt t) { return t & mask; }

            static constexpr Int calcN2(Int N) {
                // - N * N2 = 1 (mod R)
                // N2 = -N^{-1} (mod R)

                // Calculates N^{-1} (mod R) by Newton's method.
                DInt invN = N; // = N^{-1} (mod 2^2) for odd N
                for (uint32_t cur_bits = 2; cur_bits < bits; cur_bits *= 2) {
                    // loop invariant: invN = N^{-1} (mod 2^cur_bits)

                    // x = a^{-1} mod m => x(2-ax) = a^{-1} mod m^2 because:
                    //  ax = 1 (mod m)
                    //  => (ax-1)^2 = 0 (mod m^2)
                    //  => 2ax - a^2x^2 = 1 (mod m^2)
                    //  => a(x(2-ax)) = 1 (mod m^2)
                    invN = modR(invN * modR(2 - N * invN));
                }
                assert(modR(N * invN) == 1);

                return modR(-invN);
            }
            static constexpr Int calcR2(Int N, Int N2) {
                // R * R2 - N * N2 = 1
                // => R2 = (1 + N * N2) / R
                return divR(1 + static_cast<DInt>(N) * N2);
            }
            static constexpr Int calcRR(Int N) {
                // -N (mod 2^(2 * bits)) = 2^(2 * bits) - N, and 2^(2 * bits) - N = R * R (mod N).
                return -DInt(N) % N;
            }
        };
    } // namespace internal::montgomery
    using Montgomery32 = internal::montgomery::Montgomery<uint32_t, uint64_t>;
    using Montgomery64 = internal::montgomery::Montgomery<uint64_t, __uint128_t>;
} // namespace suisen



#line 12 "library/number/deterministic_miller_rabin.hpp"

namespace suisen::miller_rabin {
    namespace internal {
        // Base sets for the strong probable-prime (Miller-Rabin) test.
        // is_prime uses BASE_k for every n < THRESHOLD_k, and BASE_7 for every n >= THRESHOLD_6.
        // Passing all bases of the selected set is taken as proof of primality in that range.
        // Bases larger than n are reduced modulo n by miller_rabin.
        // NOTE(review): these appear to be the published deterministic SPRP base records;
        // confirm the source before editing any constant.
        constexpr uint64_t THRESHOLD_1 = 341531ULL;
        constexpr uint64_t BASE_1[]{ 9345883071009581737ULL };

        constexpr uint64_t THRESHOLD_2 = 1050535501ULL;
        constexpr uint64_t BASE_2[]{ 336781006125ULL, 9639812373923155ULL };

        constexpr uint64_t THRESHOLD_3 = 350269456337ULL;
        constexpr uint64_t BASE_3[]{ 4230279247111683200ULL, 14694767155120705706ULL, 16641139526367750375ULL };

        constexpr uint64_t THRESHOLD_4 = 55245642489451ULL;
        constexpr uint64_t BASE_4[]{ 2ULL, 141889084524735ULL, 1199124725622454117ULL, 11096072698276303650ULL };

        constexpr uint64_t THRESHOLD_5 = 7999252175582851ULL;
        constexpr uint64_t BASE_5[]{ 2ULL, 4130806001517ULL, 149795463772692060ULL, 186635894390467037ULL, 3967304179347715805ULL };

        constexpr uint64_t THRESHOLD_6 = 585226005592931977ULL;
        constexpr uint64_t BASE_6[]{ 2ULL, 123635709730000ULL, 9233062284813009ULL, 43835965440333360ULL, 761179012939631437ULL, 1263739024124850375ULL };

        // Used for all remaining n (up to 2^64 - 1); no threshold is checked.
        constexpr uint64_t BASE_7[]{ 2U, 325U, 9375U, 28178U, 450775U, 9780504U, 1795265022U };

        // Strong probable-prime test of n against the first SIZE entries of BASE.
        // Handles n <= 1, the primes 2, 3, 5, 7 and their multiples by trial division.
        // Any remaining n < 121 is reported prime.
        // @returns false if some base witnesses that n is composite, true otherwise.
        template <auto BASE, std::size_t SIZE>
        constexpr bool miller_rabin(uint64_t n) {
            if (n == 2 or n == 3 or n == 5 or n == 7) return true;
            if (n <= 1 or n % 2 == 0 or n % 3 == 0 or n % 5 == 0 or n % 7 == 0) return false;
            // n < 121 = 11^2 with no prime factor <= 7 must be prime.
            if (n < 121) return true;

            // Write n - 1 = d * 2^s with d odd.
            const uint32_t s = __builtin_ctzll(n - 1); // >= 1
            const uint64_t d = (n - 1) >> s;

            const Montgomery64 mg{ n };

            // 1 and -1 (mod n) in Montgomery form, so they compare directly against Y.
            const uint64_t one = mg.make(1), minus_one = mg.make(n - 1);

            for (std::size_t i = 0; i < SIZE; ++i) {
                uint64_t a = BASE[i] % n;
                // A base that is a multiple of n gives no information; skip it.
                if (a == 0) continue;
                uint64_t Y = mg.pow(mg.make(a), d);
                if (Y == one) continue;
                for (uint32_t r = 0;; ++r, Y = mg.mul(Y, Y)) {
                    // Y = a^(d 2^r)
                    if (Y == minus_one) break;
                    // None of a^d, a^(2d), ..., a^(2^(s-1) d) equals -1: a witnesses compositeness.
                    if (r == s - 1) return false;
                }
            }
            return true;
        }
    }

    // Deterministic primality test.
    // @param n integer to test; must satisfy 0 <= n < 2^64 (checked with assert).
    // @returns true iff n is prime. A negative n returns false even when asserts are disabled.
    template <typename T, std::enable_if_t<std::is_integral_v<T>, std::nullptr_t> = nullptr>
    constexpr bool is_prime(T n) {
        if constexpr (std::is_signed_v<T>) {
            assert(n >= 0);
            // Under NDEBUG the assert vanishes. Without this guard a negative n would be
            // converted to a large unsigned value below and tested in its place.
            if (n < 0) return false;
        }
        const std::make_unsigned_t<T> n_unsigned = n;
        assert(n_unsigned <= std::numeric_limits<uint64_t>::max()); // n < 2^64
        using namespace internal;
        // Use the smallest base set whose threshold exceeds n.
        if (n_unsigned < THRESHOLD_1) return miller_rabin<BASE_1, 1>(n_unsigned);
        if (n_unsigned < THRESHOLD_2) return miller_rabin<BASE_2, 2>(n_unsigned);
        if (n_unsigned < THRESHOLD_3) return miller_rabin<BASE_3, 3>(n_unsigned);
        if (n_unsigned < THRESHOLD_4) return miller_rabin<BASE_4, 4>(n_unsigned);
        if (n_unsigned < THRESHOLD_5) return miller_rabin<BASE_5, 5>(n_unsigned);
        if (n_unsigned < THRESHOLD_6) return miller_rabin<BASE_6, 6>(n_unsigned);
        return miller_rabin<BASE_7, 7>(n_unsigned);
    }
} // namespace suisen::miller_rabin
Back to top page