#include "fps/root-finding.hpp"
#pragma once #include <random> #include <vector> using namespace std; #include "formal-power-series.hpp" #include "mod-pow.hpp" #include "polynomial-gcd.hpp" template <typename mint> vector<mint> root_finding(const FormalPowerSeries<mint>& f) { using fps = FormalPowerSeries<mint>; long long p = mint::get_mod(); vector<mint> ans; if (p == 2) { for (int i = 0; i < 2; i++) { if (f.eval(i) == 0) ans.push_back(i); } return ans; } vector<fps> fs; fs.push_back(PolyGCD(mod_pow(p, fps{0, 1}, f) - fps{0, 1}, f)); mt19937_64 rng(58); while (!fs.empty()) { auto g = fs.back(); fs.pop_back(); if (g.size() == 2) ans.push_back(-g[0]); if (g.size() <= 2) continue; fps s = fps{(long long)(rng() % p), 1}; fps t = PolyGCD(mod_pow((p - 1) / 2, s, g) - fps{1}, g); fs.push_back(t); if (g.size() != t.size()) fs.push_back(g / t); } return ans; }
#line 2 "fps/root-finding.hpp"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>
using namespace std;

#line 2 "fps/formal-power-series.hpp"

// Polynomial / formal power series with coefficients of type mint.
// Element i of the underlying vector is the coefficient of x^i (see eval()).
// Trailing zero coefficients are allowed; shrink() removes them.
template <typename mint>
struct FormalPowerSeries : vector<mint> {
  using vector<mint>::vector;
  using FPS = FormalPowerSeries;

  // Coefficient-wise addition; grows *this to r's length if needed.
  FPS &operator+=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }

  // Adds a constant to the x^0 coefficient.
  FPS &operator+=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }

  // Coefficient-wise subtraction; grows *this to r's length if needed.
  FPS &operator-=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }

  // Subtracts a constant from the x^0 coefficient.
  FPS &operator-=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }

  // Multiplies every coefficient by v.
  FPS &operator*=(const mint &v) {
    for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
    return *this;
  }

  // Polynomial quotient floor(*this / r). The result has exactly
  // this->size() - r.size() + 1 coefficients (it is emptied if *this is
  // shorter than r). Divisors with <= 64 coefficients use schoolbook long
  // division; longer ones use the reversed-polynomial / power-series-inverse
  // method.
  FPS &operator/=(const FPS &r) {
    if (this->size() < r.size()) {
      this->clear();
      return *this;
    }
    int n = this->size() - r.size() + 1;
    if ((int)r.size() <= 64) {
      FPS f(*this), g(r);
      g.shrink();
      // Make the divisor monic; the quotient is rescaled by coeff at the end.
      mint coeff = g.back().inverse();
      for (auto &x : g) x *= coeff;
      int deg = (int)f.size() - (int)g.size() + 1;
      int gs = g.size();
      FPS quo(deg);
      for (int i = deg - 1; i >= 0; i--) {
        quo[i] = f[i + gs - 1];
        for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
      }
      *this = quo * coeff;
      this->resize(n, mint(0));
      return *this;
    }
    return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
  }

  // Polynomial remainder *this mod r, with trailing zeros removed.
  FPS &operator%=(const FPS &r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }

  FPS operator+(const FPS &r) const { return FPS(*this) += r; }
  FPS operator+(const mint &v) const { return FPS(*this) += v; }
  FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
  FPS operator-(const mint &v) const { return FPS(*this) -= v; }
  FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
  FPS operator*(const mint &v) const { return FPS(*this) *= v; }
  FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
  FPS operator%(const FPS &r) const { return FPS(*this) %= r; }

  // Coefficient-wise negation.
  FPS operator-() const {
    FPS ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }

  // Removes trailing zero coefficients (the zero polynomial becomes empty).
  void shrink() {
    while (this->size() && this->back() == mint(0)) this->pop_back();
  }

  // Returns the coefficients in reverse order.
  FPS rev() const {
    FPS ret(*this);
    reverse(begin(ret), end(ret));
    return ret;
  }

  // Element-wise product, truncated to the shorter operand.
  FPS dot(FPS r) const {
    FPS ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }

  // Takes the first sz coefficients; pads with zeros when there are fewer
  // than sz.
  FPS pre(int sz) const {
    FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
    if ((int)ret.size() < sz) ret.resize(sz);
    return ret;
  }

  // Drops the lowest sz coefficients (floor division by x^sz).
  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }

  // Multiplies by x^sz (prepends sz zero coefficients).
  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret.insert(ret.begin(), sz, mint(0));
    return ret;
  }

  // Formal derivative.
  FPS diff() const {
    const int n = (int)this->size();
    FPS ret(max(0, n - 1));
    mint one(1), coeff(1);
    for (int i = 1; i < n; i++) {
      ret[i - 1] = (*this)[i] * coeff;
      coeff += one;
    }
    return ret;
  }

  // Formal antiderivative with zero constant term. ret[2..n] first holds the
  // modular inverses 1/i via inv[i] = -inv[mod % i] * (mod / i), which
  // assumes a prime modulus larger than n (NOTE(review): confirm).
  FPS integral() const {
    const int n = (int)this->size();
    FPS ret(n + 1);
    ret[0] = mint(0);
    if (n > 0) ret[1] = mint(1);
    auto mod = mint::get_mod();
    for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
    for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
    return ret;
  }

  // Evaluates the polynomial at x.
  mint eval(mint x) const {
    mint r = 0, w = 1;
    for (auto &v : *this) r += w * v, w *= x;
    return r;
  }

  // log(f) mod x^deg (deg defaults to size()); requires f[0] == 1.
  FPS log(int deg = -1) const {
    assert(!(*this).empty() && (*this)[0] == mint(1));
    if (deg == -1) deg = (int)this->size();
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }

  // f^k mod x^deg (deg defaults to size()). The lowest nonzero term
  // c * x^i is factored out and the rest is raised through log/exp.
  FPS pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      FPS ret(deg);
      if (deg) ret[0] = 1;
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != mint(0)) {
        mint rev = mint(1) / (*this)[i];
        FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
        ret *= (*this)[i].pow(k);
        ret = (ret << (i * k)).pre(deg);
        if ((int)ret.size() < deg) ret.resize(deg, mint(0));
        return ret;
      }
      // The shift x^(i*k) already exceeds the requested precision.
      if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
    }
    return FPS(deg, mint(0));
  }

  // Declared here but defined elsewhere (presumably an NTT/FFT-specific
  // header, not visible in this file).
  static void *ntt_ptr;
  static void set_fft();
  FPS &operator*=(const FPS &r);
  void ntt();
  void intt();
  void ntt_doubling();
  static int ntt_pr();
  FPS inv(int deg = -1) const;
  FPS exp(int deg = -1) const;
};
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;

/**
 * @brief Polynomial / formal power series library
 * @docs docs/fps/formal-power-series.md
 */
#line 2 "fps/mod-pow.hpp"

#line 4 "fps/mod-pow.hpp"

// Returns base^k mod d (degree < deg d, trailing zeros stripped) by binary
// exponentiation. Each product is reduced with the quotient floor(poly / d),
// obtained from a precomputed power-series inverse of reversed d.
// d must not be the zero polynomial; its trailing zero coefficients are
// ignored. base may have any length.
template <typename mint>
FormalPowerSeries<mint> mod_pow(int64_t k, const FormalPowerSeries<mint>& base,
                                const FormalPowerSeries<mint>& d) {
  using fps = FormalPowerSeries<mint>;
  // rev().inv() needs a nonzero leading coefficient, so drop trailing zeros.
  fps md(d);
  md.shrink();
  assert(!md.empty());
  // Every polynomial is 0 modulo a nonzero constant (this also covers k == 0).
  if (md.size() == 1) return fps{};
  auto inv = md.rev().inv();
  // Exact only while poly.size() <= 2 * md.size() - 1, i.e. n <= md.size()
  // (the precision of inv); guaranteed below because b and res stay reduced.
  auto quo = [&](const fps& poly) {
    if (poly.size() < md.size()) return fps{};
    int n = poly.size() - md.size() + 1;
    return (poly.rev().pre(n) * inv.pre(n)).pre(n).rev();
  };
  fps res{1}, b(base);
  // Reduce an over-long base up front; squaring it unreduced would need
  // inverse coefficients beyond inv's precision.
  if (b.size() >= md.size()) b %= md;
  while (k) {
    if (k & 1) {
      res *= b;
      res -= quo(res) * md;
      res.shrink();
    }
    b *= b;
    b -= quo(b) * md;
    b.shrink();
    k >>= 1;
    assert(b.size() + 1 <= md.size());
    assert(res.size() + 1 <= md.size());
  }
  return res;
}

/**
 * @brief Mod-Pow ($f(x)^k \mod g(x)$)
 */
#line 2 "fps/polynomial-gcd.hpp"

#line 4 "fps/polynomial-gcd.hpp"

namespace poly_gcd {
template <typename mint>
using FPS = FormalPowerSeries<mint>;
template <typename mint>
using Arr = pair<FPS<mint>, FPS<mint>>;

// 2x2 matrix of polynomials [[a00, a01], [a10, a11]]; records the
// accumulated Euclidean steps applied to a pair of polynomials.
template <typename mint>
struct Mat {
  using fps = FPS<mint>;
  fps a00, a01, a10, a11;

  Mat() = default;
  Mat(const fps& a00_, const fps& a01_, const fps& a10_, const fps& a11_)
      : a00(a00_), a01(a01_), a10(a10_), a11(a11_) {}

  // *this = *this * r (entries have trailing zeros stripped).
  Mat& operator*=(const Mat& r) {
    fps A00 = a00 * r.a00 + a01 * r.a10;
    fps A01 = a00 * r.a01 + a01 * r.a11;
    fps A10 = a10 * r.a00 + a11 * r.a10;
    fps A11 = a10 * r.a01 + a11 * r.a11;
    A00.shrink();
    A01.shrink();
    A10.shrink();
    A11.shrink();
    swap(A00, a00);
    swap(A01, a01);
    swap(A10, a10);
    swap(A11, a11);
    return *this;
  }

  // Identity matrix.
  static Mat I() { return Mat(fps{mint(1)}, fps(), fps(), fps{mint(1)}); }

  Mat operator*(const Mat& r) const { return Mat(*this) *= r; }
};

// Applies m to the column vector (a.first, a.second); results are shrunk.
template <typename mint>
Arr<mint> operator*(const Mat<mint>& m, const Arr<mint>& a) {
  using fps = FPS<mint>;
  fps b0 = m.a00 * a.first + m.a01 * a.second;
  fps b1 = m.a10 * a.first + m.a11 * a.second;
  b0.shrink();
  b1.shrink();
  return {b0, b1};
}

// One Euclidean step: p = (a, b) becomes (b, a mod b), and m is updated to
// [[row1], [row0 - quo * row1]] so it keeps tracking the transformation.
template <typename mint>
void InnerNaiveGCD(Mat<mint>& m, Arr<mint>& p) {
  using fps = FPS<mint>;
  fps quo = p.first / p.second;
  fps rem = p.first - p.second * quo;
  fps b10 = m.a00 - m.a10 * quo;
  fps b11 = m.a01 - m.a11 * quo;
  rem.shrink();
  b10.shrink();
  b11.shrink();
  swap(b10, m.a10);
  swap(b11, m.a11);
  swap(b10, m.a00);
  swap(b11, m.a01);
  p = {p.second, rem};
}

// Half-GCD recursion. With n = p.first.size() and k = (n + 1) / 2, returns a
// product of Euclidean-step matrices meant to bring p.second down to at most
// k coefficients (the identity if it already is). Recurses on the high
// parts (p >> k, p >> j) of the pair.
template <typename mint>
Mat<mint> InnerHalfGCD(Arr<mint> p) {
  int n = p.first.size(), m = p.second.size();
  int k = (n + 1) / 2;
  if (m <= k) return Mat<mint>::I();
  Mat<mint> m1 = InnerHalfGCD(make_pair(p.first >> k, p.second >> k));
  p = m1 * p;
  if ((int)p.second.size() <= k) return m1;
  InnerNaiveGCD(m1, p);
  if ((int)p.second.size() <= k) return m1;
  int l = (int)p.first.size() - 1;
  int j = 2 * k - l;
  p.first = p.first >> j;
  p.second = p.second >> j;
  return InnerHalfGCD(p) * m1;
}

// Returns M with M * (a, b) = (g, 0), where g is a gcd of a and b (not
// necessarily monic). Operands are shrunk first; if a is shorter than b the
// problem is solved for (b, a) and M's columns are swapped.
template <typename mint>
Mat<mint> InnerPolyGCD(const FPS<mint>& a, const FPS<mint>& b) {
  Arr<mint> p{a, b};
  p.first.shrink();
  p.second.shrink();
  int n = p.first.size(), m = p.second.size();
  if (n < m) {
    Mat<mint> mat = InnerPolyGCD(p.second, p.first);
    swap(mat.a00, mat.a01);
    swap(mat.a10, mat.a11);
    return mat;
  }

  Mat<mint> res = Mat<mint>::I();
  while (1) {
    Mat<mint> m1 = InnerHalfGCD(p);
    p = m1 * p;
    if (p.second.empty()) return m1 * res;
    InnerNaiveGCD(m1, p);
    if (p.second.empty()) return m1 * res;
    res = m1 * res;
  }
}

// Polynomial GCD; when the result is nonzero it is normalized to be monic.
// Returns an empty polynomial when both a and b are zero.
template <typename mint>
FPS<mint> PolyGCD(const FPS<mint>& a, const FPS<mint>& b) {
  Arr<mint> p(a, b);
  Mat<mint> m = InnerPolyGCD(a, b);
  p = m * p;
  if (!p.first.empty()) {
    mint coeff = p.first.back().inverse();
    for (auto& x : p.first) x *= coeff;
  }
  return p.first;
}

// Inverse of f modulo g. If gcd(f, g) is a nonzero constant, returns
// {true, h} with h * f == 1 (mod g); otherwise {false, empty}. The flag is
// stored as int.
template <typename mint>
pair<int, FPS<mint>> PolyInv(const FPS<mint>& f, const FPS<mint>& g) {
  using fps = FPS<mint>;
  pair<fps, fps> p(f, g);
  Mat<mint> m = InnerPolyGCD(f, g);
  fps gcd_ = (m * p).first;
  if (gcd_.size() != 1) return {false, fps()};
  // (m * (1, g)).first == a00 (mod g), and a00 * f == gcd_ (mod g).
  pair<fps, fps> x(fps{mint(1)}, g);
  return {true, ((m * x).first % g) * gcd_[0].inverse()};
}

}  // namespace poly_gcd
using poly_gcd::PolyGCD;
using poly_gcd::PolyInv;

/**
 * @brief Polynomial GCD
 * @docs docs/fps/polynomial-gcd.md
 */
#line 10 "fps/root-finding.hpp"

// Distinct roots in F_p (p = mint::get_mod(), assumed prime — NOTE(review):
// confirm) of f, in no particular order. For p == 2 both elements are
// tested directly. Otherwise g = gcd(x^p - x, f) is split with
// gcd((x + s)^((p-1)/2) - 1, g) for random s until all factors are linear.
// Trailing zero coefficients of f are ignored; for p != 2, f must be nonzero
// (mod_pow asserts). The fixed seed makes the output deterministic.
template <typename mint>
vector<mint> root_finding(const FormalPowerSeries<mint>& f) {
  using fps = FormalPowerSeries<mint>;
  long long p = mint::get_mod();
  vector<mint> ans;
  if (p == 2) {
    for (int i = 0; i < 2; i++) {
      if (f.eval(i) == 0) ans.push_back(i);
    }
    return ans;
  }
  // mod_pow inverts the reversed modulus, so strip trailing zeros first.
  fps h(f);
  h.shrink();
  vector<fps> fs;
  fs.push_back(PolyGCD(mod_pow(p, fps{0, 1}, h) - fps{0, 1}, h));
  mt19937_64 rng(58);
  while (!fs.empty()) {
    auto g = fs.back();
    fs.pop_back();
    if (g.size() == 2) ans.push_back(-g[0]);  // g = x + g[0] (monic)
    if (g.size() <= 2) continue;
    fps s = fps{(long long)(rng() % p), 1};
    fps t = PolyGCD(mod_pow((p - 1) / 2, s, g) - fps{1}, g);
    fs.push_back(t);
    if (g.size() != t.size()) fs.push_back(g / t);
  }
  return ans;
}