#include "fps/fft2d.hpp"
#pragma once #include <cassert> #include <vector> using namespace std; #include "formal-power-series.hpp" template <typename mint> void fft2d(vector<FormalPowerSeries<mint>>& a) { int H = a.size(), W = a[0].size(); assert((H & (H - 1)) == 0); assert((W & (W - 1)) == 0); for (int i = 0; i < H; i++) { bool ok = false; for (auto& x : a[i]) { if (x != mint()) { ok = true; break; } } if (ok) a[i].ntt(); } FormalPowerSeries<mint> buf(H); for (int i = 0; i < W; i++) { for (int j = 0; j < H; j++) buf[j] = a[j][i]; buf.ntt(); for (int j = 0; j < H; j++) a[j][i] = buf[j]; } } template <typename mint> void ifft2d(vector<FormalPowerSeries<mint>>& a) { int H = a.size(), W = a[0].size(); assert((H & (H - 1)) == 0); assert((W & (W - 1)) == 0); FormalPowerSeries<mint> buf(H); for (int i = 0; i < W; i++) { for (int j = 0; j < H; j++) buf[j] = a[j][i]; buf.intt(); for (int j = 0; j < H; j++) a[j][i] = buf[j]; } for (int i = 0; i < H; i++) { bool ok = false; for (auto& x : a[i]) { if (x != mint()) { ok = true; break; } } if (ok) a[i].intt(); } } template <typename mint> vector<FormalPowerSeries<mint>> transpose(vector<FormalPowerSeries<mint>> f) { int H = f.size(), W = f[0].size(); for (auto& v : f) assert((int)v.size() == W); vector<FormalPowerSeries<mint>> g(W, FormalPowerSeries<mint>(H)); for (int i = 0; i < H; i++) { for (int j = 0; j < W; j++) g[j][i] = f[i][j]; } return g; }; template <typename mint> vector<FormalPowerSeries<mint>> multiply2d_naive( vector<FormalPowerSeries<mint>> a, vector<FormalPowerSeries<mint>> b) { using fps = FormalPowerSeries<mint>; using fps2d = vector<fps>; if (a.empty() or b.empty()) return {}; if (a[0].empty() or b[0].empty()) return {}; int Ha = a.size(), Wa = a[0].size(); int Hb = b.size(), Wb = b[0].size(); for (auto& v : a) assert((int)v.size() == Wa); for (auto& v : b) assert((int)v.size() == Wb); fps2d c(Ha + Hb - 1, fps(Wa + Wb - 1)); for (int ia = 0; ia < Ha; ia++) { for (int ja = 0; ja < Wa; ja++) { for (int ib = 0; ib < Hb; ib++) { for (int 
jb = 0; jb < Wb; jb++) { c[ia + ib][ja + jb] += a[ia][ja] * b[ib][jb]; } } } } return c; } template <typename mint> vector<FormalPowerSeries<mint>> multiply2d_partially_naive( vector<FormalPowerSeries<mint>> a, vector<FormalPowerSeries<mint>> b) { using fps = FormalPowerSeries<mint>; using fps2d = vector<fps>; if (a.empty() or b.empty()) return {}; if (a[0].empty() or b[0].empty()) return {}; int Ha = a.size(), Wa = a[0].size(); int Hb = b.size(), Wb = b[0].size(); for (auto& v : a) assert((int)v.size() == Wa); for (auto& v : b) assert((int)v.size() == Wb); if (min(Ha, Hb) * min(Wa, Wb) <= 40) { return multiply2d_naive(a, b); } int W = 1; while (W < Wa + Wb - 1) W *= 2; if (W >= 64 and Wa + Wb - 1 <= W / 2 + 20) { if (Wa <= 20) swap(a, b), swap(Ha, Hb), swap(Wa, Wb); int d = Wa + Wb - 1 - W / 2; fps2d a1(Ha), a2(Ha); for (int i = 0; i < Ha; i++) { a1[i] = fps{begin(a[i]), end(a[i]) - d}; a2[i] = fps{end(a[i]) - d, end(a[i])}; } fps2d c1 = multiply2d_partially_naive(a1, b); fps2d c2 = multiply2d_partially_naive(a2, b); for (int i = 0; i < Ha + Hb - 1; i++) { c1[i] += c2[i] << (Wa - d); c1[i].resize(Wa + Wb - 1); } return c1; } for (auto& v : a) v.resize(W), v.ntt(); for (auto& v : b) v.resize(W), v.ntt(); fps2d cT; for (int j = 0; j < W; j++) { fps bufa(Ha), bufb(Hb); for (int i = 0; i < Ha; i++) bufa[i] = a[i][j]; for (int i = 0; i < Hb; i++) bufb[i] = b[i][j]; cT.push_back(bufa * bufb); } fps2d c = transpose(cT); for (auto& v : c) v.intt(), v.resize(Wa + Wb - 1); return c; } template <typename mint> vector<FormalPowerSeries<mint>> multiply2d(vector<FormalPowerSeries<mint>> a, vector<FormalPowerSeries<mint>> b) { using fps = FormalPowerSeries<mint>; using fps2d = vector<fps>; if (a.empty() or b.empty()) return {}; if (a[0].empty() or b[0].empty()) return {}; int Ha = a.size(), Wa = a[0].size(); int Hb = b.size(), Wb = b[0].size(); for (auto& v : a) assert((int)v.size() == Wa); for (auto& v : b) assert((int)v.size() == Wb); if (min(Ha, Hb) * min(Wa, Wb) <= 40) { 
return multiply2d_naive(a, b); } if (min(Ha, Hb) <= 40) { return multiply2d_partially_naive(a, b); } if (min(Wa, Wb) <= 40) { auto aT = transpose(a), bT = transpose(b); auto cT = multiply2d_partially_naive(aT, bT); return transpose(cT); } int H = 1, W = 1; while (H < Ha + Hb - 1) H *= 2; while (W < Wa + Wb - 1) W *= 2; if (Wa + Wb - 1 < W / 2 + 20) { int d = Wa + Wb - 1 - W / 2; fps2d a1(Ha), a2(Ha); for (int i = 0; i < Ha; i++) { a1[i] = fps{begin(a[i]), end(a[i]) - d}; a2[i] = fps{end(a[i]) - d, end(a[i])}; } fps2d c1 = multiply2d(a1, b); fps2d c2 = multiply2d(a2, b); for (int i = 0; i < Ha + Hb - 1; i++) { c1[i] += c2[i] << (Wa - d); c1[i].resize(Wa + Wb - 1); } return c1; } if (Ha + Hb - 1 < H / 2 + 20) { auto aT = transpose(a), bT = transpose(b); auto cT = multiply2d(aT, bT); return transpose(cT); } a.resize(H), b.resize(H); for (auto& v : a) v.resize(W); for (auto& v : b) v.resize(W); fft2d(a), fft2d(b); for (int i = 0; i < H; i++) { for (int j = 0; j < W; j++) a[i][j] *= b[i][j]; } ifft2d(a); a.resize(Ha + Hb - 1); for (auto& v : a) v.resize(Wa + Wb - 1); return a; }
#line 2 "fps/fft2d.hpp"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>
// NOTE(review): `using namespace std;` at header scope leaks into includers;
// kept because everything below uses unqualified std names (vector, min, ...).
using namespace std;

#line 2 "fps/formal-power-series.hpp"

// Polynomial / formal power series whose coefficients are stored low degree
// first: (*this)[i] is the coefficient of x^i.
template <typename mint>
struct FormalPowerSeries : vector<mint> {
  using vector<mint>::vector;
  using FPS = FormalPowerSeries;

  FPS &operator+=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }

  // Adds a constant to the x^0 coefficient (creating it if empty).
  FPS &operator+=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }

  FPS &operator-=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }

  FPS &operator-=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }

  FPS &operator*=(const mint &v) {
    for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
    return *this;
  }

  // Polynomial quotient. The result has size() - r.size() + 1 terms and is
  // empty when *this is shorter than r. r must not be the zero polynomial.
  FPS &operator/=(const FPS &r) {
    if (this->size() < r.size()) {
      this->clear();
      return *this;
    }
    int n = this->size() - r.size() + 1;
    if ((int)r.size() <= 64) {
      // Short divisor: schoolbook long division by the monic version of r.
      FPS f(*this), g(r);
      g.shrink();
      mint coeff = g.back().inverse();
      for (auto &x : g) x *= coeff;
      int deg = (int)f.size() - (int)g.size() + 1;
      int gs = g.size();
      FPS quo(deg);
      for (int i = deg - 1; i >= 0; i--) {
        quo[i] = f[i + gs - 1];
        for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
      }
      *this = quo * coeff;
      this->resize(n, mint(0));
      return *this;
    }
    // Long divisor: reverse both, multiply by the inverse series, reverse back.
    return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
  }

  // Polynomial remainder, with trailing zeros removed.
  FPS &operator%=(const FPS &r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }

  FPS operator+(const FPS &r) const { return FPS(*this) += r; }
  FPS operator+(const mint &v) const { return FPS(*this) += v; }
  FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
  FPS operator-(const mint &v) const { return FPS(*this) -= v; }
  FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
  FPS operator*(const mint &v) const { return FPS(*this) *= v; }
  FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
  FPS operator%(const FPS &r) const { return FPS(*this) %= r; }

  FPS operator-() const {
    FPS ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }

  // Remove trailing zero coefficients.
  void shrink() {
    while (this->size() && this->back() == mint(0)) this->pop_back();
  }

  // Coefficients in reverse order.
  FPS rev() const {
    FPS ret(*this);
    reverse(begin(ret), end(ret));
    return ret;
  }

  // Coefficient-wise product, truncated to the shorter length.
  FPS dot(FPS r) const {
    FPS ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }

  // Take the first sz terms; pad with zeros if there are fewer than sz.
  FPS pre(int sz) const {
    FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
    if ((int)ret.size() < sz) ret.resize(sz);
    return ret;
  }

  // Drop the first sz coefficients (divide by x^sz, discarding lower terms).
  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }

  // Prepend sz zero coefficients (multiply by x^sz).
  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret.insert(ret.begin(), sz, mint(0));
    return ret;
  }

  // Formal derivative.
  FPS diff() const {
    const int n = (int)this->size();
    FPS ret(max(0, n - 1));
    mint one(1), coeff(1);
    for (int i = 1; i < n; i++) {
      ret[i - 1] = (*this)[i] * coeff;
      coeff += one;
    }
    return ret;
  }

  // Formal integral with zero constant term. The reciprocals 1/i are built
  // with the recurrence inv[i] = -inv[mod % i] * (mod / i).
  FPS integral() const {
    const int n = (int)this->size();
    FPS ret(n + 1);
    ret[0] = mint(0);
    if (n > 0) ret[1] = mint(1);
    auto mod = mint::get_mod();
    for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
    for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
    return ret;
  }

  // Evaluate the polynomial at x.
  mint eval(mint x) const {
    mint r = 0, w = 1;
    for (auto &v : *this) r += w * v, w *= x;
    return r;
  }

  // log(f) mod x^deg (deg defaults to size()); requires f[0] == 1 (asserted).
  FPS log(int deg = -1) const {
    assert(!(*this).empty() && (*this)[0] == mint(1));
    if (deg == -1) deg = (int)this->size();
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }

  // f^k mod x^deg (deg defaults to size()). Factors out the lowest nonzero
  // term c * x^i and computes c^k * x^(i*k) * exp(k * log(f / (c x^i))).
  FPS pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      FPS ret(deg);
      if (deg) ret[0] = 1;
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != mint(0)) {
        mint rev = mint(1) / (*this)[i];
        FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
        ret *= (*this)[i].pow(k);
        ret = (ret << (i * k)).pre(deg);
        if ((int)ret.size() < deg) ret.resize(deg, mint(0));
        return ret;
      }
      // Every later term would start at degree >= (i + 1) * k >= deg.
      if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
    }
    return FPS(deg, mint(0));
  }

  // Declared here; the NTT-based definitions live outside this file.
  static void *ntt_ptr;
  static void set_fft();
  FPS &operator*=(const FPS &r);
  void ntt();
  void intt();
  void ntt_doubling();
  static int ntt_pr();
  FPS inv(int deg = -1) const;
  FPS exp(int deg = -1) const;
};
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;

/**
 * @brief Polynomial / formal power series library
 * @docs docs/fps/formal-power-series.md
 */
#line 8 "fps/fft2d.hpp"

/**
 * @brief In-place 2D transform of an H x W grid: ntt() on every row, then
 *        ntt() on every column.
 *
 * H and W must be powers of two (asserted). All-zero rows are not
 * transformed (assumes ntt() maps zero to zero). An empty grid is a no-op.
 */
template <typename mint>
void fft2d(vector<FormalPowerSeries<mint>>& a) {
  if (a.empty()) return;
  const int H = a.size(), W = a[0].size();
  assert((H & (H - 1)) == 0);
  assert((W & (W - 1)) == 0);
  for (auto& row : a) {
    if (any_of(begin(row), end(row), [](const mint& x) { return x != mint(); }))
      row.ntt();
  }
  FormalPowerSeries<mint> buf(H);
  for (int j = 0; j < W; j++) {
    for (int i = 0; i < H; i++) buf[i] = a[i][j];
    buf.ntt();
    for (int i = 0; i < H; i++) a[i][j] = buf[i];
  }
}

/**
 * @brief Inverse of fft2d: intt() on every column, then intt() on every row
 *        (all-zero rows skipped). An empty grid is a no-op.
 */
template <typename mint>
void ifft2d(vector<FormalPowerSeries<mint>>& a) {
  if (a.empty()) return;
  const int H = a.size(), W = a[0].size();
  assert((H & (H - 1)) == 0);
  assert((W & (W - 1)) == 0);
  FormalPowerSeries<mint> buf(H);
  for (int j = 0; j < W; j++) {
    for (int i = 0; i < H; i++) buf[i] = a[i][j];
    buf.intt();
    for (int i = 0; i < H; i++) a[i][j] = buf[i];
  }
  for (auto& row : a) {
    if (any_of(begin(row), end(row), [](const mint& x) { return x != mint(); }))
      row.intt();
  }
}

/**
 * @brief Transpose an H x W grid into W x H: g[j][i] = f[i][j].
 *        Rows must have equal length (asserted); {} for an empty grid.
 */
template <typename mint>
vector<FormalPowerSeries<mint>> transpose(
    const vector<FormalPowerSeries<mint>>& f) {
  if (f.empty()) return {};
  const int H = f.size(), W = f[0].size();
  for (const auto& v : f) assert((int)v.size() == W);
  vector<FormalPowerSeries<mint>> g(W, FormalPowerSeries<mint>(H));
  for (int i = 0; i < H; i++) {
    for (int j = 0; j < W; j++) g[j][i] = f[i][j];
  }
  return g;
}

/**
 * @brief Schoolbook 2D convolution, O(Ha * Wa * Hb * Wb); result is
 *        (Ha + Hb - 1) x (Wa + Wb - 1). {} if either input has no rows or
 *        no columns.
 */
template <typename mint>
vector<FormalPowerSeries<mint>> multiply2d_naive(
    const vector<FormalPowerSeries<mint>>& a,
    const vector<FormalPowerSeries<mint>>& b) {
  using fps = FormalPowerSeries<mint>;
  using fps2d = vector<fps>;
  if (a.empty() or b.empty()) return {};
  if (a[0].empty() or b[0].empty()) return {};
  const int Ha = a.size(), Wa = a[0].size();
  const int Hb = b.size(), Wb = b[0].size();
  for (const auto& v : a) assert((int)v.size() == Wa);
  for (const auto& v : b) assert((int)v.size() == Wb);
  fps2d c(Ha + Hb - 1, fps(Wa + Wb - 1));
  for (int ia = 0; ia < Ha; ia++) {
    for (int ja = 0; ja < Wa; ja++) {
      for (int ib = 0; ib < Hb; ib++) {
        for (int jb = 0; jb < Wb; jb++) {
          c[ia + ib][ja + jb] += a[ia][ja] * b[ib][jb];
        }
      }
    }
  }
  return c;
}

/**
 * @brief 2D convolution transforming only along the width; columns are
 *        multiplied as univariate series. Used by multiply2d when
 *        min(Ha, Hb) <= 40.
 */
template <typename mint>
vector<FormalPowerSeries<mint>> multiply2d_partially_naive(
    vector<FormalPowerSeries<mint>> a, vector<FormalPowerSeries<mint>> b) {
  using fps = FormalPowerSeries<mint>;
  using fps2d = vector<fps>;
  if (a.empty() or b.empty()) return {};
  if (a[0].empty() or b[0].empty()) return {};
  int Ha = a.size(), Wa = a[0].size();
  int Hb = b.size(), Wb = b[0].size();
  for (auto& v : a) assert((int)v.size() == Wa);
  for (auto& v : b) assert((int)v.size() == Wb);
  if (min(Ha, Hb) * min(Wa, Wb) <= 40) return multiply2d_naive(a, b);

  int W = 1;
  while (W < Wa + Wb - 1) W *= 2;

  // Result width barely exceeds W / 2: peel off the last d (<= 20) columns.
  if (W >= 64 and Wa + Wb - 1 <= W / 2 + 20) {
    if (Wa <= 20) swap(a, b), swap(Ha, Hb), swap(Wa, Wb);
    const int d = Wa + Wb - 1 - W / 2;
    fps2d a1(Ha), a2(Ha);
    for (int i = 0; i < Ha; i++) {
      a1[i] = fps(begin(a[i]), end(a[i]) - d);
      a2[i] = fps(end(a[i]) - d, end(a[i]));
    }
    fps2d c1 = multiply2d_partially_naive(a1, b);
    fps2d c2 = multiply2d_partially_naive(a2, b);
    for (int i = 0; i < Ha + Hb - 1; i++) {
      c1[i] += c2[i] << (Wa - d);
      c1[i].resize(Wa + Wb - 1);
    }
    return c1;
  }

  for (auto& v : a) v.resize(W), v.ntt();
  for (auto& v : b) v.resize(W), v.ntt();
  fps2d cT;
  cT.reserve(W);
  for (int j = 0; j < W; j++) {
    fps bufa(Ha), bufb(Hb);
    for (int i = 0; i < Ha; i++) bufa[i] = a[i][j];
    for (int i = 0; i < Hb; i++) bufb[i] = b[i][j];
    cT.push_back(bufa * bufb);
  }
  fps2d c = transpose(cT);
  for (auto& v : c) v.intt(), v.resize(Wa + Wb - 1);
  return c;
}

/**
 * @brief 2D convolution, result (Ha+Hb-1) x (Wa+Wb-1). Picks schoolbook,
 *        width-only transform (transposing first if the width is thin),
 *        a split/transposed recursion when a dimension barely exceeds a
 *        power of two, or a full H x W fft2d otherwise.
 */
template <typename mint>
vector<FormalPowerSeries<mint>> multiply2d(vector<FormalPowerSeries<mint>> a,
                                           vector<FormalPowerSeries<mint>> b) {
  using fps = FormalPowerSeries<mint>;
  using fps2d = vector<fps>;
  if (a.empty() or b.empty()) return {};
  if (a[0].empty() or b[0].empty()) return {};
  int Ha = a.size(), Wa = a[0].size();
  int Hb = b.size(), Wb = b[0].size();
  for (auto& v : a) assert((int)v.size() == Wa);
  for (auto& v : b) assert((int)v.size() == Wb);
  if (min(Ha, Hb) * min(Wa, Wb) <= 40) return multiply2d_naive(a, b);
  if (min(Ha, Hb) <= 40) return multiply2d_partially_naive(a, b);
  if (min(Wa, Wb) <= 40) {
    auto aT = transpose(a), bT = transpose(b);
    auto cT = multiply2d_partially_naive(aT, bT);
    return transpose(cT);
  }

  int H = 1, W = 1;
  while (H < Ha + Hb - 1) H *= 2;
  while (W < Wa + Wb - 1) W *= 2;

  if (Wa + Wb - 1 < W / 2 + 20) {
    const int d = Wa + Wb - 1 - W / 2;
    fps2d a1(Ha), a2(Ha);
    for (int i = 0; i < Ha; i++) {
      a1[i] = fps(begin(a[i]), end(a[i]) - d);
      a2[i] = fps(end(a[i]) - d, end(a[i]));
    }
    fps2d c1 = multiply2d(a1, b);
    fps2d c2 = multiply2d(a2, b);
    for (int i = 0; i < Ha + Hb - 1; i++) {
      c1[i] += c2[i] << (Wa - d);
      c1[i].resize(Wa + Wb - 1);
    }
    return c1;
  }
  if (Ha + Hb - 1 < H / 2 + 20) {
    auto aT = transpose(a), bT = transpose(b);
    auto cT = multiply2d(aT, bT);
    return transpose(cT);
  }

  // Zero-pad to H x W; all-zero padding rows are skipped inside fft2d.
  a.resize(H), b.resize(H);
  for (auto& v : a) v.resize(W);
  for (auto& v : b) v.resize(W);
  fft2d(a), fft2d(b);
  for (int i = 0; i < H; i++) {
    for (int j = 0; j < W; j++) a[i][j] *= b[i][j];
  }
  ifft2d(a);
  a.resize(Ha + Hb - 1);
  for (auto& v : a) v.resize(Wa + Wb - 1);
  return a;
}