modulo/quadratic-equation.hpp
Depends on: modulo/mod-sqrt.hpp (which in turn depends on modint/arbitrary-montgomery-modint.hpp)
Code
#pragma once
#include "mod-sqrt.hpp"
// Solves a * x^2 + b * x + c == 0 over Z/pZ with p = mint::get_mod().
// p must be odd (asserted); mod_sqrt's Tonelli-Shanks additionally assumes p
// is prime. Returns every distinct root:
//   - a == 0, b != 0: the single root of the linear equation b * x + c == 0;
//   - a == 0, b == 0, c != 0: no roots (a == b == c == 0 is rejected by assert,
//     since every x would be a root);
//   - otherwise 0 roots, 1 root (double root) or 2 distinct roots.
template <typename mint>
vector<mint> QuadraticEquation(mint a, mint b, mint c) {
  assert(mint::get_mod() % 2 != 0);
  if (a == mint()) {
    if (b == mint()) {
      assert(c != mint());
      return {};
    }
    return vector<mint>{-c * b.inverse()};
  }
  // Normalize to x^2 + b x + c == 0, i.e. (x + b/2)^2 == (b/2)^2 - c.
  const mint ia = a.inverse(), inv2 = mint(2).inverse();
  b *= ia, c *= ia;
  const mint center = -b * inv2;
  auto D = mod_sqrt(((b * inv2).pow(2) - c).get(), mint::get_mod());
  if (D == -1) return {};  // (b/2)^2 - c is a non-residue: no roots
  // Only D == 0 is a double root. D == 1 still gives the two distinct roots
  // center + 1 and center - 1 (the modulus is odd), so it must not be merged.
  if (D == 0) return vector<mint>{center};
  const mint d(D);
  return vector<mint>{center + d, center - d};
}
#line 2 "modulo/quadratic-equation.hpp"
#line 2 "modint/arbitrary-montgomery-modint.hpp"
#include <iostream>
using namespace std ;
// Modular integer in Montgomery form whose odd modulus is chosen at run time
// through set_mod(). "Lazy": the stored representative `a` is kept in
// [0, 2 * mod) instead of being fully reduced; it is folded into [0, mod) only
// by get(), operator== and operator!=.
//
// Template parameters:
//   Int / UInt   - signed / unsigned word type holding one representative;
//   Long / ULong - signed / unsigned double-width type used for products;
//   id           - a tag only: distinct ids produce distinct types, each with
//                  its own static modulus (mod, r and n2 are per-type statics).
// NOTE(review): the modulus lives in static storage, so every value of one
// instantiation shares it, and set_mod() is not synchronized.
template <typename Int, typename UInt, typename Long, typename ULong, int id>
struct ArbitraryLazyMontgomeryModIntBase {
  using mint = ArbitraryLazyMontgomeryModIntBase;
  inline static UInt mod;  // current modulus (odd, below 2^(bit_length - 2))
  inline static UInt r;    // mod^{-1} modulo 2^bit_length, computed by get_r()
  inline static UInt n2;   // R^2 mod mod, where R = 2^bit_length
  static constexpr int bit_length = sizeof(UInt) * 8;
  // Computes mod^{-1} modulo 2^bit_length by the Newton iteration
  // ret <- ret * (2 - mod * ret). Starting from ret = mod works because an
  // odd mod satisfies mod * mod == 1 modulo 8.
  static UInt get_r() {
    UInt ret = mod;
    while (mod * ret != 1) ret *= UInt(2) - mod * ret;
    return ret;
  }
  // Sets the modulus. m must be odd and below 2^(bit_length - 2), so that lazy
  // values in [0, 2m) and the intermediate sums below fit in UInt / Int.
  // Values created under a previous modulus become meaningless.
  static void set_mod(UInt m) {
    assert(m < (UInt(1u) << (bit_length - 2)));
    assert((m & 1) == 1);
    // -ULong(m) % m == 2^(2 * bit_length) mod m, assuming ULong is exactly
    // twice as wide as UInt (true for both aliases in this file).
    mod = m, n2 = -ULong(m) % m, r = get_r();
  }
  UInt a;  // Montgomery representative of x (x * R mod mod), kept in [0, 2 * mod)
  ArbitraryLazyMontgomeryModIntBase() : a(0) {}
  // Converts an ordinary, possibly negative, integer: b % mod + mod lies in
  // (0, 2 * mod), and reduce(v * R^2) is congruent to v * R modulo mod.
  ArbitraryLazyMontgomeryModIntBase(const Long &b)
      : a(reduce(ULong(b % mod + mod) * n2)){};
  // Montgomery reduction: for b below mod * R, returns a value congruent to
  // b * R^{-1} modulo mod, lying in [0, 2 * mod).
  static UInt reduce(const ULong &b) {
    return (b + ULong(UInt(b) * UInt(-r)) * mod) >> bit_length;
  }
  // Adds, then shifts by -2 * mod; the unsigned wrap-around is detected
  // through the signed Int cast, which keeps the result in [0, 2 * mod).
  mint &operator+=(const mint &b) {
    if (Int(a += b.a - 2 * mod) < 0) a += 2 * mod;
    return *this;
  }
  // Subtracts; a negative (wrapped) result is brought back into [0, 2 * mod).
  mint &operator-=(const mint &b) {
    if (Int(a -= b.a) < 0) a += 2 * mod;
    return *this;
  }
  mint &operator*=(const mint &b) {
    a = reduce(ULong(a) * b.a);
    return *this;
  }
  // Multiplies by b.inverse(); see inverse() for the non-invertible case.
  mint &operator/=(const mint &b) {
    *this *= b.inverse();
    return *this;
  }
  mint operator+(const mint &b) const { return mint(*this) += b; }
  mint operator-(const mint &b) const { return mint(*this) -= b; }
  mint operator*(const mint &b) const { return mint(*this) *= b; }
  mint operator/(const mint &b) const { return mint(*this) /= b; }
  // Comparisons first fold each lazy representative into [0, mod).
  bool operator==(const mint &b) const {
    return (a >= mod ? a - mod : a) == (b.a >= mod ? b.a - mod : b.a);
  }
  bool operator!=(const mint &b) const {
    return (a >= mod ? a - mod : a) != (b.a >= mod ? b.a - mod : b.a);
  }
  mint operator-() const { return mint(0) - mint(*this); }
  mint operator+() const { return mint(*this); }
  // Binary exponentiation: returns (*this)^n; pow(0) returns 1.
  mint pow(ULong n) const {
    mint ret(1), mul(*this);
    while (n > 0) {
      if (n & 1) ret *= mul;
      mul *= mul, n >>= 1;
    }
    return ret;
  }
  // Prints the canonical residue in [0, mod).
  friend ostream &operator<<(ostream &os, const mint &b) {
    return os << b.get();
  }
  // Reads a value of type Long and converts it under the current modulus.
  // NOTE(review): std::istream has no standard operator>> for __int128_t, so
  // this probably does not compile when used with the 64-bit alias; confirm.
  friend istream &operator>>(istream &is, mint &b) {
    Long t;
    is >> t;
    b = ArbitraryLazyMontgomeryModIntBase(t);
    return (is);
  }
  // Modular inverse via the extended Euclidean algorithm on (get(), mod).
  // NOTE(review): there is no invertibility check. For a value not coprime to
  // mod the result is not an inverse; in particular, zero maps to zero.
  mint inverse() const {
    Int x = get(), y = get_mod(), u = 1, v = 0;
    while (y > 0) {
      Int t = x / y;
      swap(x -= t * y, y);
      swap(u -= t * v, v);
    }
    return mint{u};
  }
  // Returns the canonical residue in [0, mod).
  UInt get() const {
    UInt ret = reduce(a);
    return ret >= mod ? ret - mod : ret;
  }
  static UInt get_mod() { return mod; }
};
// Instantiate with an arbitrary (e.g. random) value for `id`. Each distinct id
// is a separate type with its own static modulus, so several run-time moduli
// can be used side by side without overwriting each other.

// 32-bit words, 64-bit products; set_mod() accepts odd moduli below 2^30.
template <int id>
using ArbitraryLazyMontgomeryModInt =
    ArbitraryLazyMontgomeryModIntBase<int, unsigned, long long,
                                      unsigned long long, id>;

// 64-bit words, 128-bit products; set_mod() accepts odd moduli below 2^62.
template <int id>
using ArbitraryLazyMontgomeryModInt64bit =
    ArbitraryLazyMontgomeryModIntBase<long long, unsigned long long,
                                      __int128_t, __uint128_t, id>;
#line 3 "modulo/mod-sqrt.hpp"
int64_t mod_sqrt ( const int64_t & a , const int64_t & p ) {
assert ( 0 <= a && a < p );
if ( a < 2 ) return a ;
using Mint = ArbitraryLazyMontgomeryModInt < 409075245 > ;
Mint :: set_mod ( p );
if ( Mint ( a ). pow (( p - 1 ) >> 1 ) != 1 ) return - 1 ;
Mint b = 1 , one = 1 ;
while ( b . pow (( p - 1 ) >> 1 ) == 1 ) b += one ;
int64_t m = p - 1 , e = 0 ;
while ( m % 2 == 0 ) m >>= 1 , e += 1 ;
Mint x = Mint ( a ). pow (( m - 1 ) >> 1 );
Mint y = Mint ( a ) * x * x ;
x *= a ;
Mint z = Mint ( b ). pow ( m );
while ( y != 1 ) {
int64_t j = 0 ;
Mint t = y ;
while ( t != one ) {
j += 1 ;
t *= t ;
}
z = z . pow ( int64_t ( 1 ) << ( e - j - 1 ));
x *= z ;
z *= z ;
y *= z ;
e = j ;
}
return x . get ();
}
/**
* @brief mod sqrt(Tonelli-Shanks algorithm)
* @docs docs/modulo/mod-sqrt.md
*/
#line 4 "modulo/quadratic-equation.hpp"
// Solves a * x^2 + b * x + c == 0 over Z/pZ with p = mint::get_mod().
// p must be odd (asserted); mod_sqrt's Tonelli-Shanks additionally assumes p
// is prime. Returns every distinct root:
//   - a == 0, b != 0: the single root of the linear equation b * x + c == 0;
//   - a == 0, b == 0, c != 0: no roots (a == b == c == 0 is rejected by assert,
//     since every x would be a root);
//   - otherwise 0 roots, 1 root (double root) or 2 distinct roots.
template <typename mint>
vector<mint> QuadraticEquation(mint a, mint b, mint c) {
  assert(mint::get_mod() % 2 != 0);
  if (a == mint()) {
    if (b == mint()) {
      assert(c != mint());
      return {};
    }
    return vector<mint>{-c * b.inverse()};
  }
  // Normalize to x^2 + b x + c == 0, i.e. (x + b/2)^2 == (b/2)^2 - c.
  const mint ia = a.inverse(), inv2 = mint(2).inverse();
  b *= ia, c *= ia;
  const mint center = -b * inv2;
  auto D = mod_sqrt(((b * inv2).pow(2) - c).get(), mint::get_mod());
  if (D == -1) return {};  // (b/2)^2 - c is a non-residue: no roots
  // Only D == 0 is a double root. D == 1 still gives the two distinct roots
  // center + 1 and center - 1 (the modulus is odd), so it must not be merged.
  if (D == 0) return vector<mint>{center};
  const mint d(D);
  return vector<mint>{center + d, center - d};
}
Back to top page