fps/polynomial-interpolation.hpp
Depends on
Required by
Verified with
Code
#pragma once
#include "./formal-power-series.hpp"
#include "./multipoint-evaluation.hpp"
/**
 * Polynomial (Lagrange) interpolation through the points (xs[i], ys[i]).
 *
 * Returns the polynomial P with deg P < n (n = xs.size()) such that
 * P(xs[i]) == ys[i] for every i. With M(x) = prod_i (x - xs[i]) (the root of
 * the product tree) and w_i = M'(xs[i]) = prod_{j != i} (xs[i] - xs[j]):
 *     P(x) = sum_i (ys[i] / w_i) * prod_{j != i} (x - xs[j]),
 * assembled bottom-up over the product tree as
 *     P_node = P_left * M_right + P_right * M_left.
 *
 * Precondition: the xs are pairwise distinct; otherwise some w_i is zero and
 * the division below is meaningless.
 * Empty input yields the zero polynomial (an empty series).
 */
template <class mint>
FormalPowerSeries<mint> PolynomialInterpolation(const vector<mint> &xs,
                                                const vector<mint> &ys) {
  using fps = FormalPowerSeries<mint>;
  assert(xs.size() == ys.size());
  // No constraints: the canonical interpolant is the zero polynomial. Without
  // this guard the size-1 tree's lone padding leaf was returned as the constant 1.
  if (xs.empty()) return fps{};
  ProductTree<mint> ptree(xs);
  // vs[i] = M'(xs[i]), evaluated with the same product tree.
  fps w = ptree.buf[1].diff();
  vector<mint> vs = InnerMultipointEvaluation<mint>(w, xs, ptree);
  auto rec = [&](auto self, int idx) -> fps {
    if (idx >= ptree.N) {
      const int i = idx - ptree.N;
      if (i < (int)xs.size()) return {ys[i] / vs[i]};
      // Padding leaf: carries no data point. Not reached for non-empty input,
      // and contributes nothing if it were.
      return {};
    }
    // Left child empty => the whole subtree is padding.
    if (ptree.buf[idx << 1 | 0].empty()) return {};
    // Right child empty => its product is 1, so only the left part remains.
    if (ptree.buf[idx << 1 | 1].empty()) return self(self, idx << 1 | 0);
    return self(self, idx << 1 | 0) * ptree.buf[idx << 1 | 1] +
           self(self, idx << 1 | 1) * ptree.buf[idx << 1 | 0];
  };
  return rec(rec, 1);
}
#line 2 "fps/polynomial-interpolation.hpp"
#line 2 "fps/formal-power-series.hpp"
/**
 * Formal power series / polynomial with coefficients of type `mint`.
 *
 * (*this)[i] is the coefficient of x^i. Operations that need a fast transform
 * (FPS * FPS, inv, exp, ntt, intt, ntt_doubling, set_fft, ntt_pr) are only
 * declared here and are defined elsewhere in the library.
 */
template <typename mint>
struct FormalPowerSeries : vector<mint> {
  using vector<mint>::vector;
  using FPS = FormalPowerSeries;

  // Coefficient-wise sum; *this grows to r.size() when r is longer.
  FPS &operator+=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }
  // Adds the constant r to the x^0 coefficient (creating it if empty).
  FPS &operator+=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }
  // Coefficient-wise difference; *this grows to r.size() when r is longer.
  FPS &operator-=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }
  // Subtracts the constant r from the x^0 coefficient (creating it if empty).
  FPS &operator-=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }
  // Multiplies every coefficient by the scalar v.
  FPS &operator*=(const mint &v) {
    for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
    return *this;
  }
  /**
   * Polynomial quotient: replaces *this with q where *this = q * r + rem and
   * deg rem < deg r. High-order zero coefficients of r are ignored. The
   * result has size() - deg(r) coefficients, or is empty when *this has fewer
   * coefficients than deg(r) + 1. r must not be the zero polynomial.
   */
  FPS &operator/=(const FPS &r) {
    // Normalize the divisor first. Both the quotient length and the
    // reverse-and-invert step need r's true leading coefficient.
    FPS d(r);
    d.shrink();
    assert(!d.empty() && "FormalPowerSeries: division by the zero polynomial");
    if (this->size() < d.size()) {
      this->clear();
      return *this;
    }
    const int n = (int)this->size() - (int)d.size() + 1;
    if ((int)d.size() <= 64) {
      // Schoolbook long division by the monic polynomial d / lc(d).
      FPS f(*this), g(d);
      mint coeff = g.back().inverse();
      for (auto &x : g) x *= coeff;
      const int gs = g.size();
      FPS quo(n);
      for (int i = n - 1; i >= 0; i--) {
        quo[i] = f[i + gs - 1];
        for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
      }
      // Undo the normalization: f = quo * (d * coeff) + rem.
      return *this = quo * coeff;
    }
    // Large divisor: rev(q) = rev(f) * rev(d)^{-1} mod x^n.
    return *this = ((*this).rev().pre(n) * d.rev().inv(n)).pre(n).rev();
  }
  // Polynomial remainder of *this by r, with trailing zeros removed.
  FPS &operator%=(const FPS &r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }
  FPS operator+(const FPS &r) const { return FPS(*this) += r; }
  FPS operator+(const mint &v) const { return FPS(*this) += v; }
  FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
  FPS operator-(const mint &v) const { return FPS(*this) -= v; }
  FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
  FPS operator*(const mint &v) const { return FPS(*this) *= v; }
  FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
  FPS operator%(const FPS &r) const { return FPS(*this) %= r; }
  // Coefficient-wise negation.
  FPS operator-() const {
    FPS ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }
  // Removes trailing (highest-degree) zero coefficients.
  void shrink() {
    while (this->size() && this->back() == mint(0)) this->pop_back();
  }
  // Coefficients in reverse order.
  FPS rev() const {
    FPS ret(*this);
    reverse(begin(ret), end(ret));
    return ret;
  }
  // Coefficient-wise (Hadamard) product, truncated to the shorter length.
  FPS dot(const FPS &r) const {
    FPS ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }
  // Takes the first sz coefficients, zero-padding when *this is shorter.
  // A non-positive sz yields an empty series.
  FPS pre(int sz) const {
    if (sz < 0) sz = 0;  // a negative length would form an invalid iterator range
    FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
    if ((int)ret.size() < sz) ret.resize(sz);
    return ret;
  }
  // Drops the lowest sz coefficients (division by x^sz); empty if nothing remains.
  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }
  // Prepends sz zero coefficients (multiplication by x^sz).
  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret.insert(ret.begin(), sz, mint(0));
    return ret;
  }
  // Formal derivative; the result has max(0, size() - 1) coefficients.
  FPS diff() const {
    const int n = (int)this->size();
    FPS ret(max(0, n - 1));
    mint one(1), coeff(1);
    for (int i = 1; i < n; i++) {
      ret[i - 1] = (*this)[i] * coeff;
      coeff += one;
    }
    return ret;
  }
  // Antiderivative with zero constant term; the result has size() + 1 coefficients.
  FPS integral() const {
    const int n = (int)this->size();
    FPS ret(n + 1);
    ret[0] = mint(0);
    if (n > 0) ret[1] = mint(1);
    // ret[i] = i^{-1} via inv[i] = -inv[mod % i] * (mod / i).
    // This assumes mint::get_mod() is prime and exceeds n.
    auto mod = mint::get_mod();
    for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
    for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
    return ret;
  }
  // Evaluates the polynomial at x by accumulating coefficient * x^i.
  mint eval(mint x) const {
    mint r = 0, w = 1;
    for (auto &v : *this) r += w * v, w *= x;
    return r;
  }
  // First deg coefficients of log(*this); deg defaults to size().
  // Requires (*this)[0] == 1. A non-positive deg yields an empty series.
  FPS log(int deg = -1) const {
    assert(!(*this).empty() && (*this)[0] == mint(1));
    if (deg == -1) deg = (int)this->size();
    if (deg <= 0) return FPS();
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }
  // First deg coefficients of (*this)^k; deg defaults to size().
  // A non-positive deg yields an empty series.
  FPS pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (deg <= 0) return FPS();
    if (k == 0) {
      FPS ret(deg);
      ret[0] = 1;
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != mint(0)) {
        // Factor out c * x^i, where c is the lowest nonzero coefficient.
        // Then raise the unit-constant part via exp(k * log(.)).
        mint rev = mint(1) / (*this)[i];
        FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
        ret *= (*this)[i].pow(k);
        // pre(deg) already pads with zeros to exactly deg coefficients.
        return (ret << (i * k)).pre(deg);
      }
      // Lowest nonzero term is at index >= i + 1, so x^((i+1)k) already exceeds deg.
      if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
    }
    return FPS(deg, mint(0));
  }
  // Transform-based operations; defined elsewhere in the library.
  static void *ntt_ptr;
  static void set_fft();
  FPS &operator*=(const FPS &r);
  void ntt();
  void intt();
  void ntt_doubling();
  static int ntt_pr();
  FPS inv(int deg = -1) const;
  FPS exp(int deg = -1) const;
};
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;
/**
* @brief 多項式/形式的冪級数ライブラリ
* @docs docs/fps/formal-power-series.md
*/
#line 2 "fps/multipoint-evaluation.hpp"
#line 4 "fps/multipoint-evaluation.hpp"
// Product tree over the points xs, stored in segment-tree layout.
// Nodes are 1..2N-1 and the children of node k are 2k and 2k+1.
// Leaf N + i (0 <= i < xs.size()) holds the linear factor x - xs[i].
// Each internal node holds the product of its children, so after
// construction buf[1] = prod_i (x - xs[i]) (empty when xs is empty).
// Leaves N + i with i >= xs.size() are padding and keep an empty buf.
// Because leaves are filled left to right, each node k covers the contiguous
// half-open index range [l[k], r[k]). Subtrees that contain only padding have
// l == r == xs.size().
template < typename mint >
struct ProductTree {
using fps = FormalPowerSeries < mint > ;
// NOTE(review): stored by reference, not copied. The vector passed to the
// constructor must outlive this tree.
const vector < mint > & xs ;
// buf[k]: node k's product polynomial as a coefficient vector, without
// trailing zeros, once construction has finished.
vector < fps > buf ;
// N: number of leaves, the smallest power of two >= xs.size() (at least 1).
// xsz: number of points.
int N , xsz ;
vector < int > l , r ;
// Builds the tree. fps::set_fft() is defined elsewhere. When it leaves
// fps::ntt_ptr null, the tree uses plain polynomial products (build).
// Otherwise it uses the transform-domain construction (build_ntt).
ProductTree ( const vector < mint > & xs_ ) : xs ( xs_ ), xsz ( xs . size ()) {
N = 1 ;
while ( N < ( int ) xs . size ()) N *= 2 ;
buf . resize ( 2 * N );
l . resize ( 2 * N , xs . size ());
r . resize ( 2 * N , xs . size ());
fps :: set_fft ();
if ( fps :: ntt_ptr == nullptr )
build ();
else
build_ntt ();
}
// Coefficient-domain construction.
void build () {
// Leaves: the factor x - xs[i], stored as coefficients { -xs[i], 1 }.
for ( int i = 0 ; i < xsz ; i ++ ) {
l [ i + N ] = i ;
r [ i + N ] = i + 1 ;
buf [ i + N ] = { - xs [ i ], 1 };
}
// Internal nodes, bottom-up. A node whose left child is empty has only
// padding below it and stays empty. If only the right child is empty, the
// node copies its left child. Otherwise it multiplies the two children.
for ( int i = N - 1 ; i > 0 ; i -- ) {
l [ i ] = l [( i << 1 ) | 0 ];
r [ i ] = r [( i << 1 ) | 1 ];
if ( buf [( i << 1 ) | 0 ]. empty ())
continue ;
else if ( buf [( i << 1 ) | 1 ]. empty ())
buf [ i ] = buf [( i << 1 ) | 0 ];
else
buf [ i ] = buf [( i << 1 ) | 0 ] * buf [( i << 1 ) | 1 ];
}
}
// Same tree as build(). Node polynomials are kept in a transformed
// representation while combining (ntt, intt and ntt_doubling are defined
// elsewhere). The final loop converts every node back with intt() and trims
// trailing zeros.
void build_ntt () {
// Scratch buffer for the right child, reused across nodes.
fps f ;
f . reserve ( N * 2 );
// Leaves: { -xs[i] + 1, -xs[i] - 1 } are the values of x - xs[i] at x = 1
// and x = -1, presumably the size-2 transform of { -xs[i], 1 }.
for ( int i = 0 ; i < xsz ; i ++ ) {
l [ i + N ] = i ;
r [ i + N ] = i + 1 ;
buf [ i + N ] = { - xs [ i ] + 1 , - xs [ i ] - 1 };
}
for ( int i = N - 1 ; i > 0 ; i -- ) {
l [ i ] = l [( i << 1 ) | 0 ];
r [ i ] = r [( i << 1 ) | 1 ];
if ( buf [( i << 1 ) | 0 ]. empty ())
continue ;
else if ( buf [( i << 1 ) | 1 ]. empty ())
buf [ i ] = buf [( i << 1 ) | 0 ];
else if ( buf [( i << 1 ) | 0 ]. size () == buf [( i << 1 ) | 1 ]. size ()) {
// Children of equal transform length: extend both with ntt_doubling,
// then multiply pointwise.
buf [ i ] = buf [( i << 1 ) | 0 ];
f . clear ();
copy ( begin ( buf [( i << 1 ) | 1 ]), end ( buf [( i << 1 ) | 1 ]),
back_inserter ( f ));
buf [ i ]. ntt_doubling ();
f . ntt_doubling ();
for ( int j = 0 ; j < ( int ) buf [ i ]. size (); j ++ ) buf [ i ][ j ] *= f [ j ];
} else {
// Lengths differ. The left subtree holds at least as many points,
// since leaves are filled left to right. Extend the left child by
// doubling. Bring the right child to the same length via
// intt -> zero-pad -> ntt, then multiply pointwise.
buf [ i ] = buf [( i << 1 ) | 0 ];
f . clear ();
copy ( begin ( buf [( i << 1 ) | 1 ]), end ( buf [( i << 1 ) | 1 ]),
back_inserter ( f ));
buf [ i ]. ntt_doubling ();
f . intt ();
f . resize ( buf [ i ]. size (), mint ( 0 ));
f . ntt ();
for ( int j = 0 ; j < ( int ) buf [ i ]. size (); j ++ ) buf [ i ][ j ] *= f [ j ];
}
}
// Convert every node (including empty ones) back to coefficients and
// drop trailing zeros.
for ( int i = 0 ; i < 2 * N ; i ++ ) {
buf [ i ]. intt ();
buf [ i ]. shrink ();
}
}
};
template < typename mint >
vector < mint > InnerMultipointEvaluation ( const FormalPowerSeries < mint > & f ,
const vector < mint > & xs ,
const ProductTree < mint > & ptree ) {
using fps = FormalPowerSeries < mint > ;
vector < mint > ret ;
ret . reserve ( xs . size ());
auto rec = [ & ]( auto self , fps a , int idx ) {
if ( ptree . l [ idx ] == ptree . r [ idx ]) return ;
a %= ptree . buf [ idx ];
if (( int ) a . size () <= 64 ) {
for ( int i = ptree . l [ idx ]; i < ptree . r [ idx ]; i ++ )
ret . push_back ( a . eval ( xs [ i ]));
return ;
}
self ( self , a , ( idx << 1 ) | 0 );
self ( self , a , ( idx << 1 ) | 1 );
};
rec ( rec , f , 1 );
return ret ;
}
template < typename mint >
vector < mint > MultipointEvaluation ( const FormalPowerSeries < mint > & f ,
const vector < mint > & xs ) {
if ( f . empty () || xs . empty ()) return vector < mint > ( xs . size (), mint ( 0 ));
return InnerMultipointEvaluation ( f , xs , ProductTree < mint > ( xs ));
}
/**
* @brief Multipoint Evaluation
*/
#line 5 "fps/polynomial-interpolation.hpp"
/**
 * Polynomial (Lagrange) interpolation through the points (xs[i], ys[i]).
 *
 * Returns the polynomial P with deg P < n (n = xs.size()) such that
 * P(xs[i]) == ys[i] for every i. With M(x) = prod_i (x - xs[i]) (the root of
 * the product tree) and w_i = M'(xs[i]) = prod_{j != i} (xs[i] - xs[j]):
 *     P(x) = sum_i (ys[i] / w_i) * prod_{j != i} (x - xs[j]),
 * assembled bottom-up over the product tree as
 *     P_node = P_left * M_right + P_right * M_left.
 *
 * Precondition: the xs are pairwise distinct; otherwise some w_i is zero and
 * the division below is meaningless.
 * Empty input yields the zero polynomial (an empty series).
 */
template <class mint>
FormalPowerSeries<mint> PolynomialInterpolation(const vector<mint> &xs,
                                                const vector<mint> &ys) {
  using fps = FormalPowerSeries<mint>;
  assert(xs.size() == ys.size());
  // No constraints: the canonical interpolant is the zero polynomial. Without
  // this guard the size-1 tree's lone padding leaf was returned as the constant 1.
  if (xs.empty()) return fps{};
  ProductTree<mint> ptree(xs);
  // vs[i] = M'(xs[i]), evaluated with the same product tree.
  fps w = ptree.buf[1].diff();
  vector<mint> vs = InnerMultipointEvaluation<mint>(w, xs, ptree);
  auto rec = [&](auto self, int idx) -> fps {
    if (idx >= ptree.N) {
      const int i = idx - ptree.N;
      if (i < (int)xs.size()) return {ys[i] / vs[i]};
      // Padding leaf: carries no data point. Not reached for non-empty input,
      // and contributes nothing if it were.
      return {};
    }
    // Left child empty => the whole subtree is padding.
    if (ptree.buf[idx << 1 | 0].empty()) return {};
    // Right child empty => its product is 1, so only the left part remains.
    if (ptree.buf[idx << 1 | 1].empty()) return self(self, idx << 1 | 0);
    return self(self, idx << 1 | 0) * ptree.buf[idx << 1 | 1] +
           self(self, idx << 1 | 1) * ptree.buf[idx << 1 | 0];
  };
  return rec(rec, 1);
}
Back to top page