FFT-free な形式的べき級数
(library/polynomial/fps_naive.hpp)
FFT-free な形式的べき級数
積や乗法逆元 $\mathrm{inv}$,指数関数 $\exp$,対数関数 $\log$,べき乗 $\mathrm{pow}$,平方根 $\mathrm{sqrt}$ を $\Theta(N ^ 2)$ で計算する形式的べき級数ライブラリ.Set Power Series など,最大次数が十分小さいことが分かっている場合にこちらを用いることで高速化が期待できる.
FFT や NTT を必要としないので,NTT に適さない法の下でもそのまま用いることができる点も特徴である.
前置き
形式的べき級数 $f$ の $i$ 次の係数を $f_i$ と書く.
積
$h = fg$ を満たす $h$ を求める.これは,以下のように計算される.
\[h _ n = \sum _ {i = 0} ^ n f _ i g _ {n - i}.\]
FPS としての除算
$h = \dfrac{f}{g}$ を満たす $h$ を求める.ただし,$g _ 0 \neq 0$ を仮定する.$f = gh$ であるから,次のように計算される.
\[\begin{aligned}
f _ n
&= \sum _ {i = 0} ^ n h _ i g _ {n - i}\\
&= h _ n g _ 0 + \sum _ {i = 0} ^ {n - 1} h _ i g _ {n - i},\\
h _ n
&= \dfrac{1}{g _ 0}\left(f _ n - \sum _ {i = 0} ^ {n - 1}h _ i g _ {n - i} \right).
\end{aligned}\]
$\mathrm{inv}$
$gh=1$ を満たす $h$ を求める.ただし,$g _ 0 \neq 0$ を仮定する.形式的べき級数としての除算において $f\equiv 1$ とすればよいので,次のように計算される.
\[h _ n = \begin{cases}
\dfrac{1}{g _ 0} & \text{if } n = 0\\
\displaystyle -\dfrac{1}{g _ 0} \sum _ {i = 0} ^ {n - 1} h _ i g _ {n - i} & \text{otherwise}
\end{cases}.\]
$\exp$
$g = \exp f$ を満たす $g$ を求める.ただし,$f _ 0 = 0$ を仮定する.$g' = f' g$ より,次のように計算される.
\[\begin{aligned}
(n + 1) g _ {n + 1}
&= \sum _ {i = 0} ^ n (n - i + 1) f _ {n - i + 1} g _ i,\\
g _ n &= \begin{cases}
1 & \text{if }n = 0 \\
\dfrac{1}{n}\displaystyle\sum _ {i = 0} ^ {n - 1} (n - i) f _ {n - i} g _ i &\text{otherwise}
\end{cases}
\end{aligned}\]
$\log$
$g = \log f$ を満たす $g$ を求める.ただし,$f _ 0 = 1$ を仮定する.$fg'=f'$ より,次のように計算される.
\[\begin{aligned}
\sum _ {i = 0} ^ n (i + 1) g _ {i + 1} f _ {n - i} &= (n + 1)f _ {n + 1},\\
(n + 1) f _ 0 g _ {n + 1} &= (n + 1) f _ {n + 1} - \sum _ {i = 0} ^ {n - 1}(i + 1) g _ {i + 1} f _ {n - i},\\
g _ {n + 1} &= \dfrac{1}{(n + 1) f _ 0}\left((n + 1) f _ {n + 1} - \sum _ {i = 0} ^ {n - 1}(i + 1) g _ {i + 1} f _ {n - i}\right),\\
g _ n &= \begin{cases}
0 & \text{if }n = 0\\
\displaystyle \dfrac{1}{nf_0}\left(n f _ {n} - \sum _ {i = 1} ^ {n - 1}i g _ i f _ {n - i}\right) & \text{otherwise}
\end{cases}
\end{aligned}\]
$\mathrm{pow}$
$g = f ^ k$ を満たす $g$ を求める.ただし,$k \gt 0$ および $f_0\neq 0$ を仮定する.$g'=kf^{k-1} f'$ の両辺に $f$ を掛けることで $fg'=kgf'$ が得られるので,以下のように計算される.
\[\begin{aligned}
\sum _ {i = 0} ^ n (i + 1) g _ {i + 1} f _ {n - i}
&= k\sum _ {i = 0} ^ n g _ i\cdot (n - i + 1) f _ {n - i + 1}, \\
(n+1)f_0g _ {n+1}
&= k\sum _ {i = 0} ^ n g _ i\cdot (n - i + 1) f _ {n - i + 1} - \sum _ {i = 1} ^ {n} i g _ i f _ {n - i + 1} \\
&=\sum _ {i = 0} ^ {n} (k(n-i+1)-i)g_i f _{n-i+1},\\
g _ {n + 1} &= \dfrac{1}{(n + 1) f _ 0}\sum _ {i = 0} ^ {n} (k (n - i + 1) - i) g _ i f _ {n - i + 1},\\
g _ n &= \begin{cases}
f _ 0 ^ k & \text{if }n = 0\\
\displaystyle\dfrac{1}{nf_0}\sum _ {i = 0} ^ {n - 1} (k (n - i) - i) g _ i f _ {n - i} & \text{otherwise}
\end{cases}
\end{aligned}\]
$k=0$ の場合は,$g\equiv 1$ として計算される.
$k\gt 0$ かつ $f _ 0 = 0$ の場合は,$f = x ^ p \tilde{f}\ (\tilde{f} _ 0 \neq 0)$ の形に直してから $g = x ^ {pk} \tilde{f} ^ k$ として計算すればよい.そのような $\tilde{f}$ が存在しない場合,即ち $f\equiv 0$ の場合は $g \equiv 0$ である.
$\mathrm{sqrt}$
$g ^ 2 = f$ を満たす $g$ を $1$ つ求める (存在しても一意とは限らない).ただし,$f _ 0 \neq 0$ および $v ^ 2 = f _ 0$ を満たす $v$ (これを $f _ 0 ^ {1 / 2}$ と書く) が存在することを仮定する (このとき $g _ 0 = f _ 0 ^ {1/2} \neq 0$ である).これは,以下のように計算できる.
\[\begin{aligned}
f _ {n + 1} &= \sum _ {i = 0} ^ {n + 1} g _ i g _ {n + 1 - i},\\
2 g _ 0 g _ {n + 1} &= f _ {n + 1} - \sum _ {i = 1} ^ {n} g _ i g _ {n + 1 - i},\\
g _ n &= \begin{cases}
f _ 0 ^ {1/2} & \text{if }n = 0\\
\displaystyle \dfrac{1}{2 g _ 0}\left(f _ {n } - \sum _ {i = 1} ^ {n - 1} g _ i g _ {n - i}\right) & \text{otherwise}
\end{cases}.
\end{aligned}\]
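$f _ 0 = 0$ の場合は,$f _ i \neq 0$ となる最小の $i$ を取り $f = x ^ i \tilde{f}\ (\tilde{f} _ 0 \neq 0)$ と書く.$i$ が偶数であれば,$\tilde{f}$ に対して上式で計算される $\tilde{g}$ を用いて $g = x ^ {i / 2} \tilde{g}$ とすれば $g ^ 2 = f$ を満たす.$i$ が奇数であれば,条件を満たす $g$ は存在しない.また,$f \equiv 0$ の場合は $g \equiv 0$ である.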
$k$-th root (未実装)
$g ^ k = f$ を満たす $g$ を求める.ただし,$k \gt 0$,$f _ 0 \neq 0$ および $v ^ k = f _ 0$ を満たす $v$ (これを $f _ 0 ^ {1/k}$ と書く) の存在を仮定する.$\mathrm{pow}$ と同様の計算により,次を得る.
\[g _ n = \begin{cases}
f _ 0 ^ {1/k} & \text{if }n = 0\\
\displaystyle\dfrac{1}{nf_0}\sum _ {i = 0} ^ {n-1} (k ^ {- 1} (n - i) - i) g _ i f _ {n - i} & \text{otherwise}
\end{cases}\]
$\mathrm{pow}$ と組み合わせることで,より一般に $g = f ^ {p/q}$ を満たす $g$ を計算できる.
Depends on
Required by
Verified with
Code
#ifndef SUISEN_FPS_NAIVE_OPS
#define SUISEN_FPS_NAIVE_OPS
#include <cassert>
#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
#include "library/type_traits/type_traits.hpp"
#include "library/math/modint_extension.hpp"
#include "library/math/inv_mods.hpp"
namespace suisen {
    /**
     * Formal power series whose operations use naive Theta(N^2) algorithms (no FFT / NTT).
     *
     * The coefficient of x^i is stored at index i of the underlying std::vector<T>.
     * Reading through the const operator[] outside the stored range yields 0, while the
     * non-const operator[] grows the storage with zeros on demand.
     * operator* and intg() truncate their results to at most MAX_SIZE coefficients.
     *
     * T is the coefficient type (typically a modint); element_type is the innermost scalar
     * type of T and is used for the cached inverses 1/i.
     * The argument n of mul/inv/exp/log/pow/sqrt is the number of coefficients to compute
     * (it defaults to size()).
     */
    template <typename T>
    struct FPSNaive : std::vector<T> {
        static inline int MAX_SIZE = std::numeric_limits<int>::max() / 2;

        using value_type = T;
        using element_type = rec_value_type_t<T>;

        using std::vector<value_type>::vector;

        FPSNaive(const std::initializer_list<value_type> l) : std::vector<value_type>::vector(l) {}
        FPSNaive(const std::vector<value_type>& v) : std::vector<value_type>::vector(v) {}

        // Sets the truncation length used by operator* and intg() (shared by all FPSNaive<T>).
        static void set_max_size(int n) {
            FPSNaive<T>::MAX_SIZE = n;
        }

        // Coefficient of x^n, or 0 when n lies outside the stored range (including n < 0).
        const value_type operator[](int n) const {
            return 0 <= n and n <= deg() ? unsafe_get(n) : value_type{ 0 };
        }
        // Reference to the coefficient of x^n; the storage is extended with zeros if needed.
        value_type& operator[](int n) {
            assert(n >= 0);
            return ensure_deg(n), unsafe_get(n);
        }

        int size() const {
            return std::vector<value_type>::size();
        }
        // Index of the last stored coefficient (size() - 1); not normalized.
        int deg() const {
            return size() - 1;
        }
        // Removes trailing zero coefficients and returns the resulting degree (-1 for the zero series).
        int normalize() {
            while (size() and this->back() == value_type{ 0 }) this->pop_back();
            return deg();
        }
        // Keeps only the first n coefficients (i.e. reduces mod x^n).
        FPSNaive& cut_inplace(int n) {
            if (size() > n) this->resize(std::max(0, n));
            return *this;
        }
        FPSNaive cut(int n) const {
            FPSNaive f = FPSNaive(*this).cut_inplace(n);
            return f;
        }

        FPSNaive operator+() const {
            return FPSNaive(*this);
        }
        FPSNaive operator-() const {
            FPSNaive f(*this);
            for (auto& e : f) e = -e;
            return f;
        }
        FPSNaive& operator++() { return ++(*this)[0], *this; }
        FPSNaive& operator--() { return --(*this)[0], *this; }
        FPSNaive& operator+=(const value_type x) { return (*this)[0] += x, *this; }
        FPSNaive& operator-=(const value_type x) { return (*this)[0] -= x, *this; }
        FPSNaive& operator+=(const FPSNaive& g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i);
            return *this;
        }
        FPSNaive& operator-=(const FPSNaive& g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i);
            return *this;
        }
        FPSNaive& operator*=(const FPSNaive& g) { return *this = *this * g; }
        FPSNaive& operator*=(const value_type x) {
            for (auto& e : *this) e *= x;
            return *this;
        }
        FPSNaive& operator/=(const FPSNaive& g) { return *this = *this / g; }
        FPSNaive& operator%=(const FPSNaive& g) { return *this = *this % g; }
        // Multiplies by x^shamt.
        FPSNaive& operator<<=(const int shamt) {
            this->insert(this->begin(), shamt, value_type{ 0 });
            return *this;
        }
        // Divides by x^shamt, discarding the lowest shamt coefficients.
        FPSNaive& operator>>=(const int shamt) {
            if (shamt > size()) this->clear();
            else this->erase(this->begin(), this->begin() + shamt);
            return *this;
        }

        friend FPSNaive operator+(FPSNaive f, const FPSNaive& g) { f += g; return f; }
        friend FPSNaive operator+(FPSNaive f, const value_type& x) { f += x; return f; }
        friend FPSNaive operator-(FPSNaive f, const FPSNaive& g) { f -= g; return f; }
        friend FPSNaive operator-(FPSNaive f, const value_type& x) { f -= x; return f; }
        // Full product, truncated to at most MAX_SIZE coefficients.
        friend FPSNaive operator*(const FPSNaive& f, const FPSNaive& g) {
            if (f.empty() or g.empty()) return FPSNaive{};
            const int n = f.size(), m = g.size();
            FPSNaive h(std::min(MAX_SIZE, n + m - 1));
            for (int i = 0; i < n; ++i) for (int j = 0; j < m; ++j) {
                if (i + j >= MAX_SIZE) break;
                h.unsafe_get(i + j) += f.unsafe_get(i) * g.unsafe_get(j);
            }
            return h;
        }
        friend FPSNaive operator*(FPSNaive f, const value_type& x) { f *= x; return f; }
        friend FPSNaive operator/(FPSNaive f, const FPSNaive& g) { return std::move(f.div_mod(g).first); }
        friend FPSNaive operator%(FPSNaive f, const FPSNaive& g) { return std::move(f.div_mod(g).second); }
        friend FPSNaive operator*(const value_type x, FPSNaive f) { f *= x; return f; }
        friend FPSNaive operator<<(FPSNaive f, const int shamt) { f <<= shamt; return f; }
        friend FPSNaive operator>>(FPSNaive f, const int shamt) { f >>= shamt; return f; }

        /**
         * Polynomial long division: returns {q, r} with *this = q * g + r and deg r < deg g.
         * g must not be the zero polynomial (asserted). r is normalized.
         */
        std::pair<FPSNaive, FPSNaive> div_mod(FPSNaive g) const {
            FPSNaive f = *this;
            const int fd = f.normalize(), gd = g.normalize();
            assert(gd >= 0);
            if (fd < gd) return { FPSNaive{}, f };
            if (gd == 0) return { f *= g.unsafe_get(0).inv(), FPSNaive{} };
            const int k = f.deg() - gd;
            value_type head_inv = g.unsafe_get(gd).inv();
            FPSNaive q(k + 1);
            for (int i = k; i >= 0; --i) {
                value_type div = f.unsafe_get(i + gd) * head_inv;
                q.unsafe_get(i) = div;
                for (int j = 0; j <= gd; ++j) f.unsafe_get(i + j) -= div * g.unsafe_get(j);
            }
            f.cut_inplace(gd);
            f.normalize();
            return { q, f };
        }

        // Equality as series: missing coefficients count as zero, so trailing zeros are ignored.
        friend bool operator==(const FPSNaive& f, const FPSNaive& g) {
            const int n = f.size(), m = g.size();
            if (n < m) return g == f;
            for (int i = 0; i < m; ++i) if (f.unsafe_get(i) != g.unsafe_get(i)) return false;
            for (int i = m; i < n; ++i) if (f.unsafe_get(i) != 0) return false;
            return true;
        }
        friend bool operator!=(const FPSNaive& f, const FPSNaive& g) {
            return not (f == g);
        }

        // Product truncated to the first n coefficients (n defaults to size()).
        FPSNaive mul(const FPSNaive& g, int n = -1) const {
            if (n < 0) n = size();
            if (this->empty() or g.empty()) return FPSNaive{};
            const int m = size(), k = g.size();
            FPSNaive h(std::min(n, m + k - 1));
            for (int i = 0; i < m; ++i) {
                for (int j = 0, jr = std::min(k, n - i); j < jr; ++j) {
                    h.unsafe_get(i + j) += unsafe_get(i) * g.unsafe_get(j);
                }
            }
            return h;
        }
        // Formal derivative.
        FPSNaive diff() const {
            if (this->empty()) return {};
            FPSNaive g(size() - 1);
            for (int i = 1; i <= deg(); ++i) g.unsafe_get(i - 1) = unsafe_get(i) * i;
            return g;
        }
        // Formal integral with zero constant term, truncated to at most MAX_SIZE coefficients.
        FPSNaive intg() const {
            const int n = size();
            FPSNaive g(n + 1);
            for (int i = 0; i < n; ++i) g.unsafe_get(i + 1) = unsafe_get(i) * invs[i + 1];
            // cut_inplace is a no-op when size() <= MAX_SIZE (the old `deg() > MAX_SIZE` test was off by one).
            g.cut_inplace(MAX_SIZE);
            return g;
        }
        /**
         * Multiplicative inverse mod x^n: h with f h = 1 (mod x^n). Requires f_0 to be invertible.
         * h_i = -(1/f_0) * sum_{j=1}^{i} f_j h_{i-j}.
         */
        FPSNaive inv(int n = -1) const {
            if (n < 0) n = size();
            if (n == 0) return FPSNaive{};
            FPSNaive g(n);
            const value_type inv_f0 = ::inv((*this)[0]);
            g.unsafe_get(0) = inv_f0;
            for (int i = 1; i < n; ++i) {
                // f_j = 0 for j > deg f, so those terms are skipped.
                for (int j = 1, jr = std::min(i, deg()); j <= jr; ++j) g.unsafe_get(i) -= g.unsafe_get(i - j) * unsafe_get(j);
                g.unsafe_get(i) *= inv_f0;
            }
            return g;
        }
        /**
         * exp f mod x^n. Requires f_0 = 0 (asserted).
         * g_i = (1/i) * sum_{j=1}^{i} j f_j g_{i-j}, from g' = f' g.
         */
        FPSNaive exp(int n = -1) const {
            if (n < 0) n = size();
            assert((*this)[0] == value_type{ 0 });
            if (n == 0) return FPSNaive{};
            FPSNaive g(n);
            g.unsafe_get(0) = value_type{ 1 };
            for (int i = 1; i < n; ++i) {
                for (int j = 1, jr = std::min(i, deg()); j <= jr; ++j) g.unsafe_get(i) += j * g.unsafe_get(i - j) * unsafe_get(j);
                g.unsafe_get(i) *= invs[i];
            }
            return g;
        }
        /**
         * log f mod x^n. Requires f_0 = 1 (asserted).
         * g_i = (1/i) * (i f_i - sum_{j=1}^{i-1} (i-j) g_{i-j} f_j), from f g' = f'.
         */
        FPSNaive log(int n = -1) const {
            if (n < 0) n = size();
            assert((*this)[0] == value_type{ 1 });
            if (n == 0) return FPSNaive{};
            FPSNaive g(n);
            g.unsafe_get(0) = value_type{ 0 };
            for (int i = 1; i < n; ++i) {
                g.unsafe_get(i) = i * (*this)[i];
                for (int j = 1, jr = std::min(i - 1, deg()); j <= jr; ++j) g.unsafe_get(i) -= (i - j) * g.unsafe_get(i - j) * unsafe_get(j);
                g.unsafe_get(i) *= invs[i];
            }
            return g;
        }
        /**
         * f^k mod x^n for k >= 0 (asserted). f^0 is 1.
         * With f = x^z f~ (f~_0 != 0), f^k = x^{zk} f~^k, and f~^k is computed from f~ g' = k g f~'.
         */
        FPSNaive pow(const long long k, int n = -1) const {
            assert(k >= 0);
            if (n < 0) n = size();
            if (n == 0) return FPSNaive{};
            if (k == 0) {
                FPSNaive res(n);
                res.unsafe_get(0) = value_type{ 1 };
                return res;
            }
            int z = 0;
            while (z < size() and unsafe_get(z) == value_type{ 0 }) ++z;
            // f^k is divisible by x^{zk}; if zk >= n every requested coefficient is 0.
            if (z == size() or z > (n - 1) / k) return FPSNaive(n, value_type{ 0 });
            const int m = n - z * k;
            FPSNaive g(m);
            const value_type inv_f0 = ::inv(unsafe_get(z));
            g.unsafe_get(0) = unsafe_get(z).pow(k);
            for (int i = 1; i < m; ++i) {
                for (int j = 1, jr = std::min(i, deg() - z); j <= jr; ++j) g.unsafe_get(i) += (element_type{ k } * j - (i - j)) * g.unsafe_get(i - j) * unsafe_get(z + j);
                g.unsafe_get(i) *= inv_f0 * invs[i];
            }
            g <<= z * k;
            return g;
        }
        /**
         * Some g with g^2 = f (mod x^n), or std::nullopt if none exists (the lowest nonzero
         * degree of f is odd, or its coefficient has no square root). The zero series maps to zero.
         */
        std::optional<FPSNaive> safe_sqrt(int n = -1) const {
            if (n < 0) n = size();
            int dl = 0;
            while (dl < size() and unsafe_get(dl) == value_type{ 0 }) ++dl;
            if (dl == size()) return FPSNaive(n, value_type{ 0 });
            if (dl & 1) return std::nullopt;
            auto opt_g0 = ::safe_sqrt((*this)[dl]);
            if (not opt_g0.has_value()) return std::nullopt;
            // The result is divisible by x^{dl/2}; if that is >= n the truncation is all zeros.
            const int m = n - dl / 2;
            if (m <= 0) return FPSNaive(n, value_type{ 0 });
            FPSNaive g(m);
            g.unsafe_get(0) = *opt_g0;
            value_type inv_2g0 = ::inv(2 * g.unsafe_get(0));
            for (int i = 1; i < m; ++i) {
                g.unsafe_get(i) = (*this)[dl + i];
                for (int j = 1; j < i; ++j) g.unsafe_get(i) -= g.unsafe_get(j) * g.unsafe_get(i - j);
                g.unsafe_get(i) *= inv_2g0;
            }
            g <<= dl / 2;
            return g;
        }
        // Like safe_sqrt, but throws std::bad_optional_access when no square root exists.
        FPSNaive sqrt(int n = -1) const {
            if (n < 0) n = size();
            return safe_sqrt(n).value();
        }
        // Evaluates the polynomial at x (Horner's method).
        value_type eval(value_type x) const {
            value_type y = 0;
            for (int i = size() - 1; i >= 0; --i) y = y * x + unsafe_get(i);
            return y;
        }
    private:
        static inline inv_mods<element_type> invs;
        void ensure_deg(int d) {
            if (deg() < d) this->resize(d + 1, value_type{ 0 });
        }
        const value_type& unsafe_get(int i) const {
            return std::vector<value_type>::operator[](i);
        }
        value_type& unsafe_get(int i) {
            return std::vector<value_type>::operator[](i);
        }
    };
} // namespace suisen
// Square root of a formal power series to n coefficients (defaults to a.size()).
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> sqrt(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.sqrt(n);
}
// log of a formal power series to n coefficients (defaults to a.size()); requires a_0 = 1.
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> log(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.log(n);
}
// exp of a formal power series to n coefficients (defaults to a.size()); requires a_0 = 0.
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> exp(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.exp(n);
}
// a^b as a formal power series to n coefficients (defaults to a.size()); b must be non-negative.
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint, typename T>
suisen::FPSNaive<mint> pow(const suisen::FPSNaive<mint>& a, T b, int n = -1) {
    return a.pow(b, n);
}
// Multiplicative inverse of a formal power series to n coefficients (defaults to a.size()).
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> inv(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.inv(n);
}
#endif // SUISEN_FPS_NAIVE_OPS
#line 1 "library/polynomial/fps_naive.hpp"
#include <cassert>
#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
#line 1 "library/type_traits/type_traits.hpp"
#line 5 "library/type_traits/type_traits.hpp"
#include <iostream>
#line 7 "library/type_traits/type_traits.hpp"
namespace suisen {
    // Evaluates to std::nullptr_t iff every Constraints::value is true; used as a SFINAE default template argument.
    template <typename ...Constraints> using constraints_t = std::enable_if_t<std::conjunction_v<Constraints...>, std::nullptr_t>;

    // Bit width of an integral type T, sign bit included (0 for non-integral T).
    // Written as digits + sign bit (equal to the digits of make_unsigned_t<T>) so that it is
    // also well-formed for bool, for which make_unsigned is not allowed.
    template <typename T, typename = std::nullptr_t> struct bitnum { static constexpr int value = 0; };
    template <typename T> struct bitnum<T, constraints_t<std::is_integral<T>>> {
        static constexpr int value = std::numeric_limits<T>::digits + (std::is_signed_v<T> ? 1 : 0);
    };
    // Header-scope variable templates are `inline` (C++17) so all translation units share one entity.
    template <typename T> inline constexpr int bitnum_v = bitnum<T>::value;
    template <typename T, size_t n> struct is_nbit { static constexpr bool value = static_cast<size_t>(bitnum_v<T>) == n; };
    template <typename T, size_t n> inline constexpr bool is_nbit_v = is_nbit<T, n>::value;

    // Wider type in which the product of two 32/64-bit integers cannot overflow; T itself otherwise.
    template <typename T, typename = std::nullptr_t> struct safely_multipliable { using type = T; };
    template <typename T> struct safely_multipliable<T, constraints_t<std::is_signed<T>, is_nbit<T, 32>>> { using type = long long; };
    template <typename T> struct safely_multipliable<T, constraints_t<std::is_signed<T>, is_nbit<T, 64>>> { using type = __int128_t; };
    template <typename T> struct safely_multipliable<T, constraints_t<std::is_unsigned<T>, is_nbit<T, 32>>> { using type = unsigned long long; };
    template <typename T> struct safely_multipliable<T, constraints_t<std::is_unsigned<T>, is_nbit<T, 64>>> { using type = __uint128_t; };
    template <typename T> using safely_multipliable_t = typename safely_multipliable<T>::type;

    // Follows nested ::value_type typedefs down to the innermost type.
    template <typename T, typename = void> struct rec_value_type { using type = T; };
    template <typename T> struct rec_value_type<T, std::void_t<typename T::value_type>> {
        using type = typename rec_value_type<typename T::value_type>::type;
    };
    template <typename T> using rec_value_type_t = typename rec_value_type<T>::type;

    // True iff a (decayed, by-value) T has begin() and end() members.
    template <typename T> class is_iterable {
        template <typename T_> static auto test(T_ e) -> decltype(e.begin(), e.end(), std::true_type{});
        static std::false_type test(...);
    public:
        static constexpr bool value = decltype(test(std::declval<T>()))::value;
    };
    template <typename T> inline constexpr bool is_iterable_v = is_iterable<T>::value;

    // True iff `std::ostream& << T` is well-formed.
    template <typename T> class is_writable {
        template <typename T_> static auto test(T_ e) -> decltype(std::declval<std::ostream&>() << e, std::true_type{});
        static std::false_type test(...);
    public:
        static constexpr bool value = decltype(test(std::declval<T>()))::value;
    };
    template <typename T> inline constexpr bool is_writable_v = is_writable<T>::value;

    // True iff `std::istream& >> T&` is well-formed.
    template <typename T> class is_readable {
        template <typename T_> static auto test(T_ e) -> decltype(std::declval<std::istream&>() >> e, std::true_type{});
        static std::false_type test(...);
    public:
        static constexpr bool value = decltype(test(std::declval<T>()))::value;
    };
    template <typename T> inline constexpr bool is_readable_v = is_readable<T>::value;
} // namespace suisen
#line 11 "library/polynomial/fps_naive.hpp"
#line 1 "library/math/modint_extension.hpp"
#line 5 "library/math/modint_extension.hpp"
#include <optional>
/**
 * Square root modulo the prime p = mint::mod() by the Tonelli-Shanks algorithm.
 * reference: https://37zigen.com/tonelli-shanks-algorithm/
 * Returns some x with x^2 = a (mod p), or std::nullopt if a is a quadratic non-residue.
 * Uses O((log p)^2) multiplications, plus the search for a non-residue.
 */
template <typename mint>
std::optional<mint> safe_sqrt(mint a) {
    // Re-read the modulus on every call (no static cache) so runtime-modulus mint types stay correct.
    const int p = mint::mod();
    if (a == 0) return std::make_optional(mint(0));
    if (p == 2) return std::make_optional(a);
    // Euler's criterion: a is a square iff a^((p-1)/2) == 1.
    if (a.pow((p - 1) / 2) != 1) return std::nullopt;
    // Find any quadratic non-residue b.
    mint b = 1;
    while (b.pow((p - 1) / 2) == 1) ++b;
    // p - 1 = q * 2^tlz with q odd.
    const int tlz = __builtin_ctz(p - 1), q = (p - 1) >> tlz;
    mint x = a.pow((q + 1) / 2);
    b = b.pow(q);
    const mint a_inv = a.inv();  // loop-invariant
    for (int shift = 2; x * x != a; ++shift) {
        // e = x^2 / a lies in the 2-Sylow subgroup; multiplying x by b fixes one more bit of its order.
        mint e = a_inv * x * x;
        if (e.pow(1 << (tlz - shift)) != 1) x *= b;
        b *= b;
    }
    return std::make_optional(x);
}
/**
 * Calculates x such that x^2 = a mod p in O((log p)^2).
 * If no such x exists, throws std::bad_optional_access (dereferencing the empty optional
 * with operator* would be undefined behaviour instead of an error).
 */
template <typename mint>
auto sqrt(mint a) -> decltype(mint::mod(), mint()) {
    return safe_sqrt(a).value();
}
// Logarithm of a modint constant: only log(1) = 0 is defined; any other input trips the assert.
template <typename mint>
auto log(mint a) -> decltype(mint::mod(), mint()) {
    assert(a == 1);
    return mint(0);
}
// Exponential of a modint constant: only exp(0) = 1 is defined; any other input trips the assert.
template <typename mint>
auto exp(mint a) -> decltype(mint::mod(), mint()) {
    assert(a == 0);
    return mint(1);
}
// Free-function form of a.pow(b) for modint types (enabled only when mint::mod() exists).
template <typename mint, typename T>
auto pow(mint a, T b) -> decltype(mint::mod(), mint()) {
    mint powered = a.pow(b);
    return powered;
}
// Free-function form of a.inv() for modint types (enabled only when mint::mod() exists).
template <typename mint>
auto inv(mint a) -> decltype(mint::mod(), mint()) {
    mint reciprocal = a.inv();
    return reciprocal;
}
#line 1 "library/math/inv_mods.hpp"
#line 5 "library/math/inv_mods.hpp"
namespace suisen {
    /**
     * Lazily grown table of modular inverses 1/0 (placeholder 0), 1/1, 1/2, ...
     * The table is a static member, so every inv_mods<mint> object shares it.
     * The recurrence in ensure() assumes mint::mod() is prime.
     */
    template <typename mint>
    class inv_mods {
    public:
        inv_mods() = default;
        inv_mods(int n) { ensure(n); }

        // Returns 1/i, extending the table up to index i first if necessary.
        const mint& operator[](int i) const {
            ensure(i);
            return invs[i];
        }

        // Makes entries 0..n of the table available.
        static void ensure(int n) {
            int filled = invs.size();
            if (filled < 2) {
                invs = { 0, 1 };
                filled = 2;
            }
            if (filled >= n + 1) return;
            invs.resize(n + 1);
            // 1/i = -(mod / i) * 1/(mod % i), with the negation written as mod - mod / i.
            for (int i = filled; i <= n; ++i) {
                invs[i] = mint(mod - mod / i) * invs[mod % i];
            }
        }

    private:
        static std::vector<mint> invs;
        static constexpr int mod = mint::mod();
    };
    template <typename mint>
    std::vector<mint> inv_mods<mint>::invs{};

    /**
     * Batch inversion: returns {1/vs[0], ..., 1/vs[n-1]} using a single call to inv().
     * Every element must be nonzero (checked with assert).
     */
    template <typename mint>
    std::vector<mint> get_invs(const std::vector<mint>& vs) {
        const int n = vs.size();
        mint whole = 1;
        for (const mint& e : vs) {
            whole *= e;
            assert(e != 0);
        }
        // suffix[i] = vs[i] * vs[i + 1] * ... * vs[n - 1], with suffix[n] = 1.
        std::vector<mint> suffix(n + 1);
        suffix[n] = 1;
        for (int i = n; i-- > 0;) suffix[i] = suffix[i + 1] * vs[i];
        // At step i, running = (vs[0] * ... * vs[i - 1]) / whole, so running * suffix[i + 1] = 1 / vs[i].
        mint running = whole.inv();
        std::vector<mint> res(n);
        for (int i = 0; i < n; ++i) {
            res[i] = running * suffix[i + 1];
            running *= vs[i];
        }
        return res;
    }
}
#line 14 "library/polynomial/fps_naive.hpp"
namespace suisen {
    /**
     * Formal power series whose operations use naive Theta(N^2) algorithms (no FFT / NTT).
     *
     * The coefficient of x^i is stored at index i of the underlying std::vector<T>.
     * Reading through the const operator[] outside the stored range yields 0, while the
     * non-const operator[] grows the storage with zeros on demand.
     * operator* and intg() truncate their results to at most MAX_SIZE coefficients.
     *
     * T is the coefficient type (typically a modint); element_type is the innermost scalar
     * type of T and is used for the cached inverses 1/i.
     * The argument n of mul/inv/exp/log/pow/sqrt is the number of coefficients to compute
     * (it defaults to size()).
     */
    template <typename T>
    struct FPSNaive : std::vector<T> {
        static inline int MAX_SIZE = std::numeric_limits<int>::max() / 2;

        using value_type = T;
        using element_type = rec_value_type_t<T>;

        using std::vector<value_type>::vector;

        FPSNaive(const std::initializer_list<value_type> l) : std::vector<value_type>::vector(l) {}
        FPSNaive(const std::vector<value_type>& v) : std::vector<value_type>::vector(v) {}

        // Sets the truncation length used by operator* and intg() (shared by all FPSNaive<T>).
        static void set_max_size(int n) {
            FPSNaive<T>::MAX_SIZE = n;
        }

        // Coefficient of x^n, or 0 when n lies outside the stored range (including n < 0).
        const value_type operator[](int n) const {
            return 0 <= n and n <= deg() ? unsafe_get(n) : value_type{ 0 };
        }
        // Reference to the coefficient of x^n; the storage is extended with zeros if needed.
        value_type& operator[](int n) {
            assert(n >= 0);
            return ensure_deg(n), unsafe_get(n);
        }

        int size() const {
            return std::vector<value_type>::size();
        }
        // Index of the last stored coefficient (size() - 1); not normalized.
        int deg() const {
            return size() - 1;
        }
        // Removes trailing zero coefficients and returns the resulting degree (-1 for the zero series).
        int normalize() {
            while (size() and this->back() == value_type{ 0 }) this->pop_back();
            return deg();
        }
        // Keeps only the first n coefficients (i.e. reduces mod x^n).
        FPSNaive& cut_inplace(int n) {
            if (size() > n) this->resize(std::max(0, n));
            return *this;
        }
        FPSNaive cut(int n) const {
            FPSNaive f = FPSNaive(*this).cut_inplace(n);
            return f;
        }

        FPSNaive operator+() const {
            return FPSNaive(*this);
        }
        FPSNaive operator-() const {
            FPSNaive f(*this);
            for (auto& e : f) e = -e;
            return f;
        }
        FPSNaive& operator++() { return ++(*this)[0], *this; }
        FPSNaive& operator--() { return --(*this)[0], *this; }
        FPSNaive& operator+=(const value_type x) { return (*this)[0] += x, *this; }
        FPSNaive& operator-=(const value_type x) { return (*this)[0] -= x, *this; }
        FPSNaive& operator+=(const FPSNaive& g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) += g.unsafe_get(i);
            return *this;
        }
        FPSNaive& operator-=(const FPSNaive& g) {
            ensure_deg(g.deg());
            for (int i = 0; i <= g.deg(); ++i) unsafe_get(i) -= g.unsafe_get(i);
            return *this;
        }
        FPSNaive& operator*=(const FPSNaive& g) { return *this = *this * g; }
        FPSNaive& operator*=(const value_type x) {
            for (auto& e : *this) e *= x;
            return *this;
        }
        FPSNaive& operator/=(const FPSNaive& g) { return *this = *this / g; }
        FPSNaive& operator%=(const FPSNaive& g) { return *this = *this % g; }
        // Multiplies by x^shamt.
        FPSNaive& operator<<=(const int shamt) {
            this->insert(this->begin(), shamt, value_type{ 0 });
            return *this;
        }
        // Divides by x^shamt, discarding the lowest shamt coefficients.
        FPSNaive& operator>>=(const int shamt) {
            if (shamt > size()) this->clear();
            else this->erase(this->begin(), this->begin() + shamt);
            return *this;
        }

        friend FPSNaive operator+(FPSNaive f, const FPSNaive& g) { f += g; return f; }
        friend FPSNaive operator+(FPSNaive f, const value_type& x) { f += x; return f; }
        friend FPSNaive operator-(FPSNaive f, const FPSNaive& g) { f -= g; return f; }
        friend FPSNaive operator-(FPSNaive f, const value_type& x) { f -= x; return f; }
        // Full product, truncated to at most MAX_SIZE coefficients.
        friend FPSNaive operator*(const FPSNaive& f, const FPSNaive& g) {
            if (f.empty() or g.empty()) return FPSNaive{};
            const int n = f.size(), m = g.size();
            FPSNaive h(std::min(MAX_SIZE, n + m - 1));
            for (int i = 0; i < n; ++i) for (int j = 0; j < m; ++j) {
                if (i + j >= MAX_SIZE) break;
                h.unsafe_get(i + j) += f.unsafe_get(i) * g.unsafe_get(j);
            }
            return h;
        }
        friend FPSNaive operator*(FPSNaive f, const value_type& x) { f *= x; return f; }
        friend FPSNaive operator/(FPSNaive f, const FPSNaive& g) { return std::move(f.div_mod(g).first); }
        friend FPSNaive operator%(FPSNaive f, const FPSNaive& g) { return std::move(f.div_mod(g).second); }
        friend FPSNaive operator*(const value_type x, FPSNaive f) { f *= x; return f; }
        friend FPSNaive operator<<(FPSNaive f, const int shamt) { f <<= shamt; return f; }
        friend FPSNaive operator>>(FPSNaive f, const int shamt) { f >>= shamt; return f; }

        /**
         * Polynomial long division: returns {q, r} with *this = q * g + r and deg r < deg g.
         * g must not be the zero polynomial (asserted). r is normalized.
         */
        std::pair<FPSNaive, FPSNaive> div_mod(FPSNaive g) const {
            FPSNaive f = *this;
            const int fd = f.normalize(), gd = g.normalize();
            assert(gd >= 0);
            if (fd < gd) return { FPSNaive{}, f };
            if (gd == 0) return { f *= g.unsafe_get(0).inv(), FPSNaive{} };
            const int k = f.deg() - gd;
            value_type head_inv = g.unsafe_get(gd).inv();
            FPSNaive q(k + 1);
            for (int i = k; i >= 0; --i) {
                value_type div = f.unsafe_get(i + gd) * head_inv;
                q.unsafe_get(i) = div;
                for (int j = 0; j <= gd; ++j) f.unsafe_get(i + j) -= div * g.unsafe_get(j);
            }
            f.cut_inplace(gd);
            f.normalize();
            return { q, f };
        }

        // Equality as series: missing coefficients count as zero, so trailing zeros are ignored.
        friend bool operator==(const FPSNaive& f, const FPSNaive& g) {
            const int n = f.size(), m = g.size();
            if (n < m) return g == f;
            for (int i = 0; i < m; ++i) if (f.unsafe_get(i) != g.unsafe_get(i)) return false;
            for (int i = m; i < n; ++i) if (f.unsafe_get(i) != 0) return false;
            return true;
        }
        friend bool operator!=(const FPSNaive& f, const FPSNaive& g) {
            return not (f == g);
        }

        // Product truncated to the first n coefficients (n defaults to size()).
        FPSNaive mul(const FPSNaive& g, int n = -1) const {
            if (n < 0) n = size();
            if (this->empty() or g.empty()) return FPSNaive{};
            const int m = size(), k = g.size();
            FPSNaive h(std::min(n, m + k - 1));
            for (int i = 0; i < m; ++i) {
                for (int j = 0, jr = std::min(k, n - i); j < jr; ++j) {
                    h.unsafe_get(i + j) += unsafe_get(i) * g.unsafe_get(j);
                }
            }
            return h;
        }
        // Formal derivative.
        FPSNaive diff() const {
            if (this->empty()) return {};
            FPSNaive g(size() - 1);
            for (int i = 1; i <= deg(); ++i) g.unsafe_get(i - 1) = unsafe_get(i) * i;
            return g;
        }
        // Formal integral with zero constant term, truncated to at most MAX_SIZE coefficients.
        FPSNaive intg() const {
            const int n = size();
            FPSNaive g(n + 1);
            for (int i = 0; i < n; ++i) g.unsafe_get(i + 1) = unsafe_get(i) * invs[i + 1];
            // cut_inplace is a no-op when size() <= MAX_SIZE (the old `deg() > MAX_SIZE` test was off by one).
            g.cut_inplace(MAX_SIZE);
            return g;
        }
        /**
         * Multiplicative inverse mod x^n: h with f h = 1 (mod x^n). Requires f_0 to be invertible.
         * h_i = -(1/f_0) * sum_{j=1}^{i} f_j h_{i-j}.
         */
        FPSNaive inv(int n = -1) const {
            if (n < 0) n = size();
            if (n == 0) return FPSNaive{};
            FPSNaive g(n);
            const value_type inv_f0 = ::inv((*this)[0]);
            g.unsafe_get(0) = inv_f0;
            for (int i = 1; i < n; ++i) {
                // f_j = 0 for j > deg f, so those terms are skipped.
                for (int j = 1, jr = std::min(i, deg()); j <= jr; ++j) g.unsafe_get(i) -= g.unsafe_get(i - j) * unsafe_get(j);
                g.unsafe_get(i) *= inv_f0;
            }
            return g;
        }
        /**
         * exp f mod x^n. Requires f_0 = 0 (asserted).
         * g_i = (1/i) * sum_{j=1}^{i} j f_j g_{i-j}, from g' = f' g.
         */
        FPSNaive exp(int n = -1) const {
            if (n < 0) n = size();
            assert((*this)[0] == value_type{ 0 });
            if (n == 0) return FPSNaive{};
            FPSNaive g(n);
            g.unsafe_get(0) = value_type{ 1 };
            for (int i = 1; i < n; ++i) {
                for (int j = 1, jr = std::min(i, deg()); j <= jr; ++j) g.unsafe_get(i) += j * g.unsafe_get(i - j) * unsafe_get(j);
                g.unsafe_get(i) *= invs[i];
            }
            return g;
        }
        /**
         * log f mod x^n. Requires f_0 = 1 (asserted).
         * g_i = (1/i) * (i f_i - sum_{j=1}^{i-1} (i-j) g_{i-j} f_j), from f g' = f'.
         */
        FPSNaive log(int n = -1) const {
            if (n < 0) n = size();
            assert((*this)[0] == value_type{ 1 });
            if (n == 0) return FPSNaive{};
            FPSNaive g(n);
            g.unsafe_get(0) = value_type{ 0 };
            for (int i = 1; i < n; ++i) {
                g.unsafe_get(i) = i * (*this)[i];
                for (int j = 1, jr = std::min(i - 1, deg()); j <= jr; ++j) g.unsafe_get(i) -= (i - j) * g.unsafe_get(i - j) * unsafe_get(j);
                g.unsafe_get(i) *= invs[i];
            }
            return g;
        }
        /**
         * f^k mod x^n for k >= 0 (asserted). f^0 is 1.
         * With f = x^z f~ (f~_0 != 0), f^k = x^{zk} f~^k, and f~^k is computed from f~ g' = k g f~'.
         */
        FPSNaive pow(const long long k, int n = -1) const {
            assert(k >= 0);
            if (n < 0) n = size();
            if (n == 0) return FPSNaive{};
            if (k == 0) {
                FPSNaive res(n);
                res.unsafe_get(0) = value_type{ 1 };
                return res;
            }
            int z = 0;
            while (z < size() and unsafe_get(z) == value_type{ 0 }) ++z;
            // f^k is divisible by x^{zk}; if zk >= n every requested coefficient is 0.
            if (z == size() or z > (n - 1) / k) return FPSNaive(n, value_type{ 0 });
            const int m = n - z * k;
            FPSNaive g(m);
            const value_type inv_f0 = ::inv(unsafe_get(z));
            g.unsafe_get(0) = unsafe_get(z).pow(k);
            for (int i = 1; i < m; ++i) {
                for (int j = 1, jr = std::min(i, deg() - z); j <= jr; ++j) g.unsafe_get(i) += (element_type{ k } * j - (i - j)) * g.unsafe_get(i - j) * unsafe_get(z + j);
                g.unsafe_get(i) *= inv_f0 * invs[i];
            }
            g <<= z * k;
            return g;
        }
        /**
         * Some g with g^2 = f (mod x^n), or std::nullopt if none exists (the lowest nonzero
         * degree of f is odd, or its coefficient has no square root). The zero series maps to zero.
         */
        std::optional<FPSNaive> safe_sqrt(int n = -1) const {
            if (n < 0) n = size();
            int dl = 0;
            while (dl < size() and unsafe_get(dl) == value_type{ 0 }) ++dl;
            if (dl == size()) return FPSNaive(n, value_type{ 0 });
            if (dl & 1) return std::nullopt;
            auto opt_g0 = ::safe_sqrt((*this)[dl]);
            if (not opt_g0.has_value()) return std::nullopt;
            // The result is divisible by x^{dl/2}; if that is >= n the truncation is all zeros.
            const int m = n - dl / 2;
            if (m <= 0) return FPSNaive(n, value_type{ 0 });
            FPSNaive g(m);
            g.unsafe_get(0) = *opt_g0;
            value_type inv_2g0 = ::inv(2 * g.unsafe_get(0));
            for (int i = 1; i < m; ++i) {
                g.unsafe_get(i) = (*this)[dl + i];
                for (int j = 1; j < i; ++j) g.unsafe_get(i) -= g.unsafe_get(j) * g.unsafe_get(i - j);
                g.unsafe_get(i) *= inv_2g0;
            }
            g <<= dl / 2;
            return g;
        }
        // Like safe_sqrt, but throws std::bad_optional_access when no square root exists.
        FPSNaive sqrt(int n = -1) const {
            if (n < 0) n = size();
            return safe_sqrt(n).value();
        }
        // Evaluates the polynomial at x (Horner's method).
        value_type eval(value_type x) const {
            value_type y = 0;
            for (int i = size() - 1; i >= 0; --i) y = y * x + unsafe_get(i);
            return y;
        }
    private:
        static inline inv_mods<element_type> invs;
        void ensure_deg(int d) {
            if (deg() < d) this->resize(d + 1, value_type{ 0 });
        }
        const value_type& unsafe_get(int i) const {
            return std::vector<value_type>::operator[](i);
        }
        value_type& unsafe_get(int i) {
            return std::vector<value_type>::operator[](i);
        }
    };
} // namespace suisen
// Square root of a formal power series to n coefficients (defaults to a.size()).
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> sqrt(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.sqrt(n);
}
// log of a formal power series to n coefficients (defaults to a.size()); requires a_0 = 1.
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> log(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.log(n);
}
// exp of a formal power series to n coefficients (defaults to a.size()); requires a_0 = 0.
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> exp(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.exp(n);
}
// a^b as a formal power series to n coefficients (defaults to a.size()); b must be non-negative.
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint, typename T>
suisen::FPSNaive<mint> pow(const suisen::FPSNaive<mint>& a, T b, int n = -1) {
    return a.pow(b, n);
}
// Multiplicative inverse of a formal power series to n coefficients (defaults to a.size()).
// Takes a by const reference to avoid copying the coefficient vector.
template <typename mint>
suisen::FPSNaive<mint> inv(const suisen::FPSNaive<mint>& a, int n = -1) {
    return a.inv(n);
}
Back to top page