Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:heavy_check_mark: 多項式/形式的冪級数ライブラリ
(fps/formal-power-series.hpp)

形式的冪級数

多項式/形式的冪級数の基本操作を行うライブラリ。

次の3つのファイルから成り、mod が $998244353$ のような NTT 素数の時は 1 番目と 2 番目を、 $10^9 + 7$ の時は 1 番目と 3 番目を include して使用する。

実装の一部はうしさんのライブラリを大きく参考にしました。感謝…

加算・減算

$\mathrm{O}(N)$

乗算

多項式 $f(x), g(x)$ に対して $h(x)=f(x)g(x)$ となる $h(x)$ を求める。これは FFT/NTT(FMT) で求められる。
計算量は $\deg(f) + \deg(g) = N$ として $\mathrm{O}(N \log N)$ である。

除算

(注:この項に書かれている除算は多項式としての除算である。)

$f(x) = g(x)q(x) + r(x)$ となる多項式 $q,r$ を求めたい。( $\deg(f) = n - 1, \deg(g) = m - 1, \deg(q) = n-m \geq 0$ とする。) ここで、

\[\mathrm{rev}(f) := f(x^{-1})\cdot x^{n-1}\]

のようにおくと $\mathrm{rev}(f)$ もまた $n-1$ 次の多項式になる。最初の式を $x \leftarrow x^{-1}$ に置き換えると

\[f(x^{-1}) = g(x^{-1})q(x^{-1})+r(x^{-1})\]

両辺に $x^{n-1}$ を掛けると

\[f(x^{-1}) x^{n-1} = g(x^{-1})x^{m-1}\cdot q(x^{-1})x^{n-m}+r(x^{-1})x^{m-2} \cdot x^{n-m+1}\] \[\leftrightarrow \mathrm{rev}(f) = \mathrm{rev}(g) \mathrm{rev}(q) + \mathrm{rev}(r) x^{n-m+1}\] \[\rightarrow \mathrm{rev}(f) \equiv \mathrm{rev}(g) \mathrm{rev}(q) \pmod{x^{n-m+1}}\]

を得る。$\mathrm{rev}(g)$ の定数項は非 $0$ なので $\text{mod }x^{n-m+1}$ 上で逆元を取ることができて (逆元の出し方は後述)

\[\frac{\mathrm{rev}(f)}{\mathrm{rev}(g)} \equiv \mathrm{rev}(q) \pmod{x^{n-m+1}}\]

となるのでこの式から $q$ を計算できて、$r$ もまた $f - gq$ から計算できる。
計算量は $\deg(f) + \deg(g) = N$ として $\mathrm{O}(N \log N)$ である。

微分・積分

$\mathrm{O}(N)$

ダブリング

$f \bmod{x^n}$ を求めたい時に、$f \equiv f_0 \pmod{x}$ から始めて精度を倍々にして求める手法のことをダブリングと呼ぶ。

具体的には、$\hat{f} \equiv f \pmod{x^k}$ から $f \bmod{x^{2k}}$ を

\[(f - \hat{f})^2 \equiv 0 \pmod{x^{2k}}\]

を利用して計算する。

逆元

$fg\equiv 1 \pmod{x^n}$ となる $f$ の逆元 $g$ をダブリングで求めたい。

まず $g \equiv f_0^{-1} \pmod{x}$ である。次に、$g \equiv \hat{g} \pmod{x^k}$ が分かっているとき $g \bmod{x^{2k}}$ を求める。

\[(g-\hat{g})^2\equiv g^2-2g\hat{g}+\hat{g}^2\equiv0 \pmod{x^{2k}}\]

両辺に $f$ を掛けて

\[fg^2-2fg\hat{g}+f\hat{g}^2\equiv0 \pmod{x^{2k}}\]

$fg \equiv 1 \pmod{x^{2k}}$ を利用して $fg$ を消すと

\[g\equiv2\hat{g}-f\hat{g}^2 \pmod{x^{2k}}\]

を得る。計算量は $T(n)=T(n/2)+\mathrm{O}(n \log n)$ を解いて $T(n)=\mathrm{O}(N \log N)$ となる (出力の次数を $N$ とする)

(なお、mod が NTT 素数の場合は実装を工夫することで定数倍が 2 倍以上軽くなる。)

log

$f_0 = 1$ を満たす $f$ に対して、$f \equiv e^g \mod x^n$ を満たす $g$ を $g\equiv\log f \pmod{x^n}$ として定義する。この時、$g$ は

\[\log f \equiv \int \frac{f'}{f} dx \pmod{x^n}\]

から求まる。(なお、上の式は一見すると不定積分の定数項のズレが心配に見えるが、$g_0 = 0$ とおくと $x=0$ を代入したとき両辺が等しくなるので問題ない。) 計算量は $\mathrm{O}(N \log N)$ である。(出力の次数を $N$ とする)

ニュートン法

ダブリングを利用した逆元の計算ではいささか唐突に $(g-\hat{g})^2$ を計算したが、このアルゴリズムを一般化した手法としてニュートン法を説明する。

数値解析におけるニュートン法とは、$f(x)=0$ を満たす $x$ の値を、漸化式

\[x_{n+1} = x_{n}-\frac{f(x_n)}{f'(x_n)}\]

を利用した反復計算により真の値を得るアルゴリズムであった。このアルゴリズムを形式的冪級数にも応用する。

$G(g) \equiv 0 \pmod{x^n}$ を満たす形式的冪級数 $g$ を求めたい。$G(\hat{g}) \equiv 0 \pmod{x^n}$ を満たす $\hat{g}$ が求まっている時、$\hat{g}$ と $g \bmod{x^{2n}}$ の間に成り立つ関係式を考える。

$G(g)$ の $g = \hat{g}$ におけるテイラー展開の式は

\[G(g) = G(\hat{g}) + G'(\hat{g})(g-\hat{g})+\mathrm{O}((g-\hat{g})^2)\]

となる。両辺 $\text{mod }x^{2n}$ を取ると ($g - \hat{g}$ の定数項は $0$ なので形式的冪級数とみなして考えても問題が起こらない)、$(g-\hat{g})^2 \equiv 0 \pmod{x ^{2n}}$ より

\[0 \equiv G(g) \equiv G(\hat{g}) + G'(\hat{g})(g-\hat{g}) \pmod{x^{2n}}\]

であり、これをさらに変形して

\[g \equiv \hat{g} - \frac{G(\hat{g})}{G'(\hat{g})} \pmod{x^{2n}}\]

を得る。この式を使って $\mathrm{exp}(f)$ や $\mathrm{sqrt}(f)$ を計算することが出来る。

exp

$g \equiv e^f \pmod{x^n}$ となる $g$ を求める。

\[\log g \equiv f \pmod{x^n}\]

等式が成り立つには $f_0=0$ が必要でこの時 $g_0=1$ である。次にニュートン法を使うと、

\[g \equiv \hat{g} - \frac{\log\hat{g}-f}{\log'\hat{g}} \pmod{x^{2n}}\]

$\log’ \hat{g}\equiv\frac{1}{\hat{g}}$ より

\[g\equiv \hat{g}(1-\log \hat{g}+f) \pmod{x^{2n}}\]

となり、この式を利用して長さを倍々にしていくことができる。計算量は $\mathrm{O}(N \log N)$ である。(出力の次数を $N$ とする)

累乗

$g \equiv f^k \pmod{x^n}$ となる $g$ を求めたい。繰り返し自乗法を用いると $\mathrm{O}(N \log k \log N)$ で求まるがexpを使うともっと早くなる。

$f$ の一番次数の低い項が $a_p x^p \ (p \neq 0)$ のときは $\left(\frac{f}{a_p x^p}\right)^k$ を計算して最後に $a_p^kx^{kp}$ を掛ければよいため、$f$ の定数項が $1$ である場合のみ考えればよい。このとき $f^k$ は $f^k\equiv e^{(\log f)k}$ から計算できる。
計算量は $\mathrm{O}(N \log N)$ である。(出力の次数を $N$ とする)

平方根

fps/fps-sqrt.hpp にて解説。

三角関数

fps/fps-circular.hpp にて解説。

平行移動

fps/taylor-shift.hpp にて解説。

階差(imos法)/累積和

(関数は未実装)
数列に対して階差を取る(いわゆる imos 法)、あるいはその逆に累積和を取る操作を形式的冪級数の数式的な操作に置き換えると、それぞれ $1-x$ を掛ける/割る操作に対応している。計算量はともに $\mathrm{O}(N)$ である。

Required by

Verified with

Code

#pragma once

template <typename mint>
struct FormalPowerSeries : vector<mint> {
  using vector<mint>::vector;
  using FPS = FormalPowerSeries;

  FPS &operator+=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }

  FPS &operator+=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }

  FPS &operator-=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }

  FPS &operator-=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }

  FPS &operator*=(const mint &v) {
    for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
    return *this;
  }

  FPS &operator/=(const FPS &r) {
    if (this->size() < r.size()) {
      this->clear();
      return *this;
    }
    int n = this->size() - r.size() + 1;
    if ((int)r.size() <= 64) {
      FPS f(*this), g(r);
      g.shrink();
      mint coeff = g.back().inverse();
      for (auto &x : g) x *= coeff;
      int deg = (int)f.size() - (int)g.size() + 1;
      int gs = g.size();
      FPS quo(deg);
      for (int i = deg - 1; i >= 0; i--) {
        quo[i] = f[i + gs - 1];
        for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
      }
      *this = quo * coeff;
      this->resize(n, mint(0));
      return *this;
    }
    return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
  }

  FPS &operator%=(const FPS &r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }

  FPS operator+(const FPS &r) const { return FPS(*this) += r; }
  FPS operator+(const mint &v) const { return FPS(*this) += v; }
  FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
  FPS operator-(const mint &v) const { return FPS(*this) -= v; }
  FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
  FPS operator*(const mint &v) const { return FPS(*this) *= v; }
  FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
  FPS operator%(const FPS &r) const { return FPS(*this) %= r; }
  FPS operator-() const {
    FPS ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }

  void shrink() {
    while (this->size() && this->back() == mint(0)) this->pop_back();
  }

  FPS rev() const {
    FPS ret(*this);
    reverse(begin(ret), end(ret));
    return ret;
  }

  FPS dot(FPS r) const {
    FPS ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }

  // 前 sz 項を取ってくる。sz に足りない項は 0 埋めする
  FPS pre(int sz) const {
    FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
    if ((int)ret.size() < sz) ret.resize(sz);
    return ret;
  }

  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }

  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret.insert(ret.begin(), sz, mint(0));
    return ret;
  }

  FPS diff() const {
    const int n = (int)this->size();
    FPS ret(max(0, n - 1));
    mint one(1), coeff(1);
    for (int i = 1; i < n; i++) {
      ret[i - 1] = (*this)[i] * coeff;
      coeff += one;
    }
    return ret;
  }

  FPS integral() const {
    const int n = (int)this->size();
    FPS ret(n + 1);
    ret[0] = mint(0);
    if (n > 0) ret[1] = mint(1);
    auto mod = mint::get_mod();
    for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
    for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
    return ret;
  }

  mint eval(mint x) const {
    mint r = 0, w = 1;
    for (auto &v : *this) r += w * v, w *= x;
    return r;
  }

  FPS log(int deg = -1) const {
    assert(!(*this).empty() && (*this)[0] == mint(1));
    if (deg == -1) deg = (int)this->size();
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }

  FPS pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      FPS ret(deg);
      if (deg) ret[0] = 1;
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != mint(0)) {
        mint rev = mint(1) / (*this)[i];
        FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
        ret *= (*this)[i].pow(k);
        ret = (ret << (i * k)).pre(deg);
        if ((int)ret.size() < deg) ret.resize(deg, mint(0));
        return ret;
      }
      if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
    }
    return FPS(deg, mint(0));
  }

  static void *ntt_ptr;
  static void set_fft();
  FPS &operator*=(const FPS &r);
  void ntt();
  void intt();
  void ntt_doubling();
  static int ntt_pr();
  FPS inv(int deg = -1) const;
  FPS exp(int deg = -1) const;
};
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;

/**
 * @brief 多項式/形式的冪級数ライブラリ
 * @docs docs/fps/formal-power-series.md
 */
#line 2 "fps/formal-power-series.hpp"

template <typename mint>
struct FormalPowerSeries : vector<mint> {
  using vector<mint>::vector;
  using FPS = FormalPowerSeries;

  FPS &operator+=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }

  FPS &operator+=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }

  FPS &operator-=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }

  FPS &operator-=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }

  FPS &operator*=(const mint &v) {
    for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
    return *this;
  }

  FPS &operator/=(const FPS &r) {
    if (this->size() < r.size()) {
      this->clear();
      return *this;
    }
    int n = this->size() - r.size() + 1;
    if ((int)r.size() <= 64) {
      FPS f(*this), g(r);
      g.shrink();
      mint coeff = g.back().inverse();
      for (auto &x : g) x *= coeff;
      int deg = (int)f.size() - (int)g.size() + 1;
      int gs = g.size();
      FPS quo(deg);
      for (int i = deg - 1; i >= 0; i--) {
        quo[i] = f[i + gs - 1];
        for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
      }
      *this = quo * coeff;
      this->resize(n, mint(0));
      return *this;
    }
    return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
  }

  FPS &operator%=(const FPS &r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }

  FPS operator+(const FPS &r) const { return FPS(*this) += r; }
  FPS operator+(const mint &v) const { return FPS(*this) += v; }
  FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
  FPS operator-(const mint &v) const { return FPS(*this) -= v; }
  FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
  FPS operator*(const mint &v) const { return FPS(*this) *= v; }
  FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
  FPS operator%(const FPS &r) const { return FPS(*this) %= r; }
  FPS operator-() const {
    FPS ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }

  void shrink() {
    while (this->size() && this->back() == mint(0)) this->pop_back();
  }

  FPS rev() const {
    FPS ret(*this);
    reverse(begin(ret), end(ret));
    return ret;
  }

  FPS dot(FPS r) const {
    FPS ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }

  // 前 sz 項を取ってくる。sz に足りない項は 0 埋めする
  FPS pre(int sz) const {
    FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
    if ((int)ret.size() < sz) ret.resize(sz);
    return ret;
  }

  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }

  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret.insert(ret.begin(), sz, mint(0));
    return ret;
  }

  FPS diff() const {
    const int n = (int)this->size();
    FPS ret(max(0, n - 1));
    mint one(1), coeff(1);
    for (int i = 1; i < n; i++) {
      ret[i - 1] = (*this)[i] * coeff;
      coeff += one;
    }
    return ret;
  }

  FPS integral() const {
    const int n = (int)this->size();
    FPS ret(n + 1);
    ret[0] = mint(0);
    if (n > 0) ret[1] = mint(1);
    auto mod = mint::get_mod();
    for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
    for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
    return ret;
  }

  mint eval(mint x) const {
    mint r = 0, w = 1;
    for (auto &v : *this) r += w * v, w *= x;
    return r;
  }

  FPS log(int deg = -1) const {
    assert(!(*this).empty() && (*this)[0] == mint(1));
    if (deg == -1) deg = (int)this->size();
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }

  FPS pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      FPS ret(deg);
      if (deg) ret[0] = 1;
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != mint(0)) {
        mint rev = mint(1) / (*this)[i];
        FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
        ret *= (*this)[i].pow(k);
        ret = (ret << (i * k)).pre(deg);
        if ((int)ret.size() < deg) ret.resize(deg, mint(0));
        return ret;
      }
      if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
    }
    return FPS(deg, mint(0));
  }

  static void *ntt_ptr;
  static void set_fft();
  FPS &operator*=(const FPS &r);
  void ntt();
  void intt();
  void ntt_doubling();
  static int ntt_pr();
  FPS inv(int deg = -1) const;
  FPS exp(int deg = -1) const;
};
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;

/**
 * @brief 多項式/形式的冪級数ライブラリ
 * @docs docs/fps/formal-power-series.md
 */
Back to top page