Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

✔ Black Box Linear Algebra
(matrix/black-box-linear-algebra.hpp)

Depends on

Verified with

Code

#include "../fps/berlekamp-massey.hpp"
#include "../fps/formal-power-series.hpp"
#include "../fps/mod-pow.hpp"
#include "../misc/rng.hpp"
//
namespace BBLAImpl {

template <typename mint>
mint inner_product(const FormalPowerSeries<mint>& a,
                   const FormalPowerSeries<mint>& b) {
  mint res = 0;
  int n = a.size();
  assert(n == (int)b.size());
  for (int i = 0; i < n; i++) res += a[i] * b[i];
  return res;
}

template <typename mint>
FormalPowerSeries<mint> random_poly(int n) {
  FormalPowerSeries<mint> res(n);
  for (auto& x : res) x = randint(0, mint::get_mod());
  return res;
}

template <typename mint>
struct ModMatrix : vector<FormalPowerSeries<mint>> {
  using fps = FormalPowerSeries<mint>;

  ModMatrix(int n) : vector<fps>(n, fps(n)) {}

  inline void add(int i, int j, mint x) { (*this)[i][j] += x; }

  friend fps operator*(const ModMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (int j = 0; j < n; j++) res[i] += m[i][j] * r[j];
    return res;
  }

  void apply(int i, mint r) {
    int n = (*this).size();
    for (int j = 0; j < n; j++) (*this)[i][j] *= r;
  }
};

template <typename mint>
struct SparseMatrix : vector<vector<pair<int, mint>>> {
  using fps = FormalPowerSeries<mint>;

  template <typename... Args>
  SparseMatrix(Args... args) : vector<vector<pair<int, mint>>>(args...) {}

  inline void add(int i, int j, mint x) { (*this)[i].emplace_back(j, x); }

  friend fps operator*(const SparseMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (auto&& [j, x] : m[i]) res[i] += x * r[j];
    return res;
  }

  void apply(int i, mint r) {
    for (auto&& [_, x] : (*this)[i]) x *= r;
  }
};

template <typename mint>
FormalPowerSeries<mint> vector_minpoly(
    const vector<FormalPowerSeries<mint>>& b) {
  assert(!b.empty());
  int n = b.size(), m = b[0].size();
  FormalPowerSeries<mint> u = random_poly<mint>(m), a(n);
  for (int i = 0; i < n; i++) a[i] = inner_product(b[i], u);
  auto mp = BerlekampMassey<mint>(a);
  return {mp.begin(), mp.end()};
}

template <typename mint, typename Mat>
FormalPowerSeries<mint> mat_minpoly(const Mat& A) {
  int n = A.size();
  FormalPowerSeries<mint> u = random_poly<mint>(n);
  vector<FormalPowerSeries<mint>> b(n * 2 + 1);
  for (int i = 0; i < (int)b.size(); i++) b[i] = u, u = A * u;
  FormalPowerSeries<mint> mp = vector_minpoly(b);
  return mp;
}

// calculate A^k b
template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_pow(const Mat& A, FormalPowerSeries<mint> b,
                                 int64_t k) {
  using fps = FormalPowerSeries<mint>;
  int n = b.size();
  fps mp = mat_minpoly<mint, Mat>(A);
  fps c = mod_pow<mint>(k, fps{0, 1}, mp.rev());
  fps res(n);
  for (int i = 0; i < (int)c.size(); i++) res += b * c[i], b = A * b;
  return res;
}

template <typename mint, typename Mat>
mint fast_det(const Mat& A) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  fps D;
  while (true) {
    do {
      D = random_poly<mint>(n);
    } while (any_of(begin(D), end(D), [](mint x) { return x == mint(0); }));

    Mat AD = A;
    for (int i = 0; i < n; i++) AD.apply(i, D[i]);
    fps mp = mat_minpoly<mint, Mat>(AD);
    if (mp.back() == 0) return 0;
    if ((int)mp.size() != n + 1) continue;
    mint det = n & 1 ? -mp.back() : mp.back();
    mint Ddet = 1;
    for (auto& d : D) Ddet *= d;
    return det / Ddet;
  }
  exit(1);
}

template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_linear_equation(const Mat& A, const FormalPowerSeries<mint>& b) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  fps mp = mat_minpoly<mint, Mat>(A).rev();
  fps buf = b, res(n);
  for (int i = 1; i < (int)mp.size(); i++) {
    res = buf * mp[i];
    buf = A * buf;
  }
  return buf * mp[0].inverse();
}

}  // namespace BBLAImpl

using BBLAImpl::fast_det;
using BBLAImpl::fast_pow;
using BBLAImpl::ModMatrix;
using BBLAImpl::SparseMatrix;

/**
 * @brief Black Box Linear Algebra
 */
#line 1 "matrix/black-box-linear-algebra.hpp"

#line 2 "fps/berlekamp-massey.hpp"

// Berlekamp-Massey algorithm over a field.
// Given s[0..N-1], returns r[0..L] with r[0] == 1 such that
//   r[0] * s[t] + r[1] * s[t-1] + ... + r[L] * s[t-L] == 0   for L <= t < N,
// i.e. a linear recurrence for s (the algorithm yields the shortest one).
// Uses O(N^2) field operations. The division x / y assumes mint is a field.
template <typename mint>
vector<mint> BerlekampMassey(const vector<mint> &s) {
  const int N = (int)s.size();
  // c: current recurrence, kept in ascending order (c.back() == 1) until the
  //    final reverse; c.back() multiplies the newest term s[ed - 1].
  // b: recurrence in effect before the last length change. It gains one
  //    trailing zero per step so that b.back() lines up with c.back().
  // y: discrepancy recorded at that last length change.
  vector<mint> b, c;
  b.reserve(N + 1);
  c.reserve(N + 1);
  b.push_back(mint(1));
  c.push_back(mint(1));
  mint y = mint(1);
  for (int ed = 1; ed <= N; ed++) {
    int l = int(c.size()), m = int(b.size());
    // x: discrepancy of the current recurrence at the window ending at s[ed - 1].
    mint x = 0;
    for (int i = 0; i < l; i++) x += c[i] * s[ed - l + i];
    b.emplace_back(mint(0));
    m++;
    if (x == mint(0)) continue;
    mint freq = x / y;
    if (l < m) {
      // The recurrence must grow. Left-pad c to length m, cancel the
      // discrepancy with freq * b, and remember the old c as the new b.
      auto tmp = c;
      c.insert(begin(c), m - l, mint(0));
      for (int i = 0; i < m; i++) c[m - 1 - i] -= freq * b[m - 1 - i];
      b = tmp;
      y = x;
    } else {
      // Length unchanged: cancel the discrepancy using freq * b, aligned at the back.
      for (int i = 0; i < m; i++) c[l - 1 - i] -= freq * b[m - 1 - i];
    }
  }
  // Convert to the r[0] == 1 (connection polynomial) form.
  reverse(begin(c), end(c));
  return c;
}
#line 2 "fps/formal-power-series.hpp"

template <typename mint>
struct FormalPowerSeries : vector<mint> {
  using vector<mint>::vector;
  using FPS = FormalPowerSeries;

  // Coefficient-wise *this += r (extends *this with zeros if r is longer).
  FPS &operator+=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
    return *this;
  }

  // Adds the constant r to the x^0 coefficient (creating it if empty).
  FPS &operator+=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] += r;
    return *this;
  }

  // Coefficient-wise *this -= r (extends *this with zeros if r is longer).
  FPS &operator-=(const FPS &r) {
    if (r.size() > this->size()) this->resize(r.size());
    for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
    return *this;
  }

  // Subtracts the constant r from the x^0 coefficient (creating it if empty).
  FPS &operator-=(const mint &r) {
    if (this->empty()) this->resize(1);
    (*this)[0] -= r;
    return *this;
  }

  // Multiplies every coefficient by the scalar v.
  FPS &operator*=(const mint &v) {
    for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
    return *this;
  }

  // Polynomial quotient: *this becomes floor(*this / r), stored with exactly
  // size() - r.size() + 1 coefficients (cleared if size() < r.size()).
  // Divisors with at most 64 coefficients use schoolbook long division.
  // Larger ones multiply the reversed dividend by the series inverse of
  // rev(r), which requires r.back() != 0.
  FPS &operator/=(const FPS &r) {
    if (this->size() < r.size()) {
      this->clear();
      return *this;
    }
    int n = this->size() - r.size() + 1;
    if ((int)r.size() <= 64) {
      FPS f(*this), g(r);
      g.shrink();
      // Division by the zero polynomial: g.back() below would be invalid.
      assert(!g.empty());
      // Make g monic; the quotient is scaled back by coeff at the end.
      mint coeff = g.back().inverse();
      for (auto &x : g) x *= coeff;
      int deg = (int)f.size() - (int)g.size() + 1;
      int gs = g.size();
      FPS quo(deg);
      for (int i = deg - 1; i >= 0; i--) {
        quo[i] = f[i + gs - 1];
        for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
      }
      *this = quo * coeff;
      this->resize(n, mint(0));
      return *this;
    }
    return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
  }

  // Remainder: *this becomes *this mod r, with trailing zeros removed.
  FPS &operator%=(const FPS &r) {
    *this -= *this / r * r;
    shrink();
    return *this;
  }

  // Value-returning forms of the compound operators above.
  FPS operator+(const FPS &r) const { return FPS(*this) += r; }
  FPS operator+(const mint &v) const { return FPS(*this) += v; }
  FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
  FPS operator-(const mint &v) const { return FPS(*this) -= v; }
  FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
  FPS operator*(const mint &v) const { return FPS(*this) *= v; }
  FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
  FPS operator%(const FPS &r) const { return FPS(*this) %= r; }
  FPS operator-() const {
    FPS ret(this->size());
    for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
    return ret;
  }

  // Removes trailing zero coefficients.
  void shrink() {
    while (this->size() && this->back() == mint(0)) this->pop_back();
  }

  // Coefficients in reverse order.
  FPS rev() const {
    FPS ret(*this);
    reverse(begin(ret), end(ret));
    return ret;
  }

  // Coefficient-wise product, truncated to the shorter length.
  FPS dot(const FPS &r) const {
    FPS ret(min(this->size(), r.size()));
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
    return ret;
  }

  // Returns the first sz terms, padding with zeros when fewer than sz exist.
  // A negative sz is treated as 0 (returns an empty series).
  FPS pre(int sz) const {
    sz = max(sz, 0);
    FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
    if ((int)ret.size() < sz) ret.resize(sz);
    return ret;
  }

  // Division by x^sz (drops the lowest sz coefficients).
  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret.erase(ret.begin(), ret.begin() + sz);
    return ret;
  }

  // Multiplication by x^sz (prepends sz zeros).
  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret.insert(ret.begin(), sz, mint(0));
    return ret;
  }

  // Formal derivative.
  FPS diff() const {
    const int n = (int)this->size();
    FPS ret(max(0, n - 1));
    mint one(1), coeff(1);
    for (int i = 1; i < n; i++) {
      ret[i - 1] = (*this)[i] * coeff;
      coeff += one;
    }
    return ret;
  }

  // Formal antiderivative with zero constant term.
  // 1/i is built by inv[i] = -inv[mod % i] * (mod / i), which assumes the
  // modulus is a prime larger than size().
  FPS integral() const {
    const int n = (int)this->size();
    FPS ret(n + 1);
    ret[0] = mint(0);
    if (n > 0) ret[1] = mint(1);
    auto mod = mint::get_mod();
    for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
    for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
    return ret;
  }

  // Evaluates the polynomial at x (Horner-free power accumulation).
  mint eval(mint x) const {
    mint r = 0, w = 1;
    for (auto &v : *this) r += w * v, w *= x;
    return r;
  }

  // log(f) modulo x^deg (deg defaults to size()); requires f[0] == 1.
  FPS log(int deg = -1) const {
    assert(!(*this).empty() && (*this)[0] == mint(1));
    if (deg == -1) deg = (int)this->size();
    return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
  }

  // f^k modulo x^deg (deg defaults to size()).
  // Factors out the lowest nonzero term a_i x^i and computes
  // a_i^k x^{ik} * exp(k log(f / (a_i x^i))).
  FPS pow(int64_t k, int deg = -1) const {
    const int n = (int)this->size();
    if (deg == -1) deg = n;
    if (k == 0) {
      FPS ret(deg);
      if (deg) ret[0] = 1;
      return ret;
    }
    for (int i = 0; i < n; i++) {
      if ((*this)[i] != mint(0)) {
        mint rev = mint(1) / (*this)[i];
        FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
        ret *= (*this)[i].pow(k);
        ret = (ret << (i * k)).pre(deg);
        if ((int)ret.size() < deg) ret.resize(deg, mint(0));
        return ret;
      }
      // All of f[0..i] are zero, so the result starts at x^{(i+1)k} or later.
      if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
    }
    return FPS(deg, mint(0));
  }

  // Only declared here; the definitions come from another header (for example
  // an NTT backend). These provide FPS multiplication, NTT helpers, the
  // series inverse and exp.
  static void *ntt_ptr;
  static void set_fft();
  FPS &operator*=(const FPS &r);
  void ntt();
  void intt();
  void ntt_doubling();
  static int ntt_pr();
  FPS inv(int deg = -1) const;
  FPS exp(int deg = -1) const;
};
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;

/**
 * @brief 多項式/形式的冪級数ライブラリ
 * @docs docs/fps/formal-power-series.md
 */
#line 2 "fps/mod-pow.hpp"

#line 4 "fps/mod-pow.hpp"

// Computes base(x)^k mod d(x) by binary exponentiation, for k >= 0.
// Each reduction p -> p mod d gets the quotient from the reversed polynomials
// times the power-series inverse of rev(d). That needs d.back() != 0.
// The result has fewer than d.size() coefficients, with trailing zeros removed.
template <typename mint>
FormalPowerSeries<mint> mod_pow(int64_t k, const FormalPowerSeries<mint>& base,
                                const FormalPowerSeries<mint>& d) {
  using fps = FormalPowerSeries<mint>;
  assert(!d.empty() && d.back() != mint(0));
  assert(k >= 0);  // a negative k would never reach 0 under k >>= 1
  const fps drev = d.rev();
  // Series inverse of rev(d). It is extended on demand so that the quotient
  // of an input of any length (e.g. a long base) stays exact.
  fps inv = drev.inv();
  auto quo = [&](const fps& poly) {
    if (poly.size() < d.size()) return fps{};
    int n = poly.size() - d.size() + 1;
    if ((int)inv.size() < n) inv = drev.inv(n);
    return (poly.rev().pre(n) * inv.pre(n)).pre(n).rev();
  };
  // poly <- poly mod d
  auto reduce = [&](fps& poly) {
    poly -= quo(poly) * d;
    poly.shrink();
  };
  fps res{1}, b(base);
  // Reduce both operands first. base may be longer than d, and when d is a
  // nonzero constant everything (including 1 for k == 0) is 0 modulo d.
  reduce(res);
  reduce(b);
  while (k) {
    if (k & 1) {
      res *= b;
      reduce(res);
    }
    b *= b;
    reduce(b);
    k >>= 1;
    assert(b.size() + 1 <= d.size());
    assert(res.size() + 1 <= d.size());
  }
  return res;
}

/**
 * @brief Mod-Pow ($f(x)^k \mod g(x)$)
 */
#line 2 "misc/rng.hpp"

#line 2 "internal/internal-seed.hpp"

#include <chrono>
using namespace std;

namespace internal {
// Seed taken from the high-resolution clock (nanoseconds since epoch),
// XORed with a fixed constant and mixed with xorshift-style shifts.
// The functions are inline because this header defines them and may be
// included from several translation units.
inline unsigned long long non_deterministic_seed() {
  unsigned long long m =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::high_resolution_clock::now().time_since_epoch())
          .count();
  m ^= 9845834732710364265uLL;
  m ^= m << 24, m ^= m >> 31, m ^= m << 35;
  return m;
}
// Constant seed used for reproducible local runs.
inline unsigned long long deterministic_seed() { return 88172645463325252ULL; }

// Returns a 64-bit seed value. It is fixed (deterministic_seed) when
// NyaanLocal is defined and RANDOMIZED_SEED is not.
// Note: consecutive calls may return the same value many times.
// Define RANDOMIZED_SEED to make the seed random even in local builds.
inline unsigned long long seed() {
#if defined(NyaanLocal) && !defined(RANDOMIZED_SEED)
  return deterministic_seed();
#else
  return non_deterministic_seed();
#endif
}

}  // namespace internal
#line 4 "misc/rng.hpp"

namespace my_rand {
using i64 = long long;
using u64 = unsigned long long;

// xorshift generator (shifts 7 and 9) seeded once from internal::seed().
// Returns the next 64-bit state, a value in [0, 2^64).
// Functions are inline because they are defined in a header.
inline u64 rng() {
  static u64 _x = internal::seed();
  return _x ^= _x << 7, _x ^= _x >> 9;
}

// Integer in the closed range [l, r] (modulo bias is ignored).
inline i64 rng(i64 l, i64 r) {
  assert(l <= r);
  // Width is computed in u64: r - l + 1 could overflow i64 for wide ranges.
  const u64 width = u64(r) - u64(l) + 1;
  if (width == 0) return i64(rng());  // [l, r] covers every i64 value
  return i64(u64(l) + rng() % width);
}

// Integer in the half-open range [l, r).
inline i64 randint(i64 l, i64 r) {
  assert(l < r);
  return i64(u64(l) + rng() % (u64(r) - u64(l)));
}

// Chooses n distinct numbers from [l, r), returned in ascending order
// (Floyd's sampling: each step draws from a range that grows by one).
inline std::vector<i64> randset(i64 l, i64 r, i64 n) {
  assert(l <= r && 0 <= n && n <= r - l);
  std::unordered_set<i64> s;
  for (i64 i = n; i; --i) {
    i64 m = randint(l, r + 1 - i);
    if (s.find(m) != s.end()) m = r - i;
    s.insert(m);
  }
  std::vector<i64> ret(s.begin(), s.end());
  std::sort(ret.begin(), ret.end());
  return ret;
}

// Double in [0.0, 1.0). The top 53 bits of rng() are scaled by 2^-53, so
// every result is exact and the value can never round up to 1.0.
inline double rnd() { return (rng() >> 11) * 0x1.0p-53; }

// Double in [l, r).
inline double rnd(double l, double r) {
  assert(l < r);
  return l + rnd() * (r - l);
}

// In-place Fisher-Yates shuffle of v.
template <typename T>
void randshf(std::vector<T>& v) {
  using std::swap;  // keep ADL swap (e.g. vector<bool> proxy references)
  int n = v.size();
  for (int i = 1; i < n; i++) swap(v[i], v[randint(0, i + 1)]);
}

}  // namespace my_rand

using my_rand::randint;
using my_rand::randset;
using my_rand::randshf;
using my_rand::rnd;
using my_rand::rng;
#line 6 "matrix/black-box-linear-algebra.hpp"
//
namespace BBLAImpl {

template <typename mint>
mint inner_product(const FormalPowerSeries<mint>& a,
                   const FormalPowerSeries<mint>& b) {
  mint res = 0;
  int n = a.size();
  assert(n == (int)b.size());
  for (int i = 0; i < n; i++) res += a[i] * b[i];
  return res;
}

template <typename mint>
FormalPowerSeries<mint> random_poly(int n) {
  FormalPowerSeries<mint> res(n);
  for (auto& x : res) x = randint(0, mint::get_mod());
  return res;
}

template <typename mint>
struct ModMatrix : vector<FormalPowerSeries<mint>> {
  using fps = FormalPowerSeries<mint>;

  ModMatrix(int n) : vector<fps>(n, fps(n)) {}

  inline void add(int i, int j, mint x) { (*this)[i][j] += x; }

  friend fps operator*(const ModMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (int j = 0; j < n; j++) res[i] += m[i][j] * r[j];
    return res;
  }

  void apply(int i, mint r) {
    int n = (*this).size();
    for (int j = 0; j < n; j++) (*this)[i][j] *= r;
  }
};

template <typename mint>
struct SparseMatrix : vector<vector<pair<int, mint>>> {
  using fps = FormalPowerSeries<mint>;

  template <typename... Args>
  SparseMatrix(Args... args) : vector<vector<pair<int, mint>>>(args...) {}

  inline void add(int i, int j, mint x) { (*this)[i].emplace_back(j, x); }

  friend fps operator*(const SparseMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (auto&& [j, x] : m[i]) res[i] += x * r[j];
    return res;
  }

  void apply(int i, mint r) {
    for (auto&& [_, x] : (*this)[i]) x *= r;
  }
};

template <typename mint>
FormalPowerSeries<mint> vector_minpoly(
    const vector<FormalPowerSeries<mint>>& b) {
  assert(!b.empty());
  int n = b.size(), m = b[0].size();
  FormalPowerSeries<mint> u = random_poly<mint>(m), a(n);
  for (int i = 0; i < n; i++) a[i] = inner_product(b[i], u);
  auto mp = BerlekampMassey<mint>(a);
  return {mp.begin(), mp.end()};
}

template <typename mint, typename Mat>
FormalPowerSeries<mint> mat_minpoly(const Mat& A) {
  int n = A.size();
  FormalPowerSeries<mint> u = random_poly<mint>(n);
  vector<FormalPowerSeries<mint>> b(n * 2 + 1);
  for (int i = 0; i < (int)b.size(); i++) b[i] = u, u = A * u;
  FormalPowerSeries<mint> mp = vector_minpoly(b);
  return mp;
}

// calculate A^k b
template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_pow(const Mat& A, FormalPowerSeries<mint> b,
                                 int64_t k) {
  using fps = FormalPowerSeries<mint>;
  int n = b.size();
  fps mp = mat_minpoly<mint, Mat>(A);
  fps c = mod_pow<mint>(k, fps{0, 1}, mp.rev());
  fps res(n);
  for (int i = 0; i < (int)c.size(); i++) res += b * c[i], b = A * b;
  return res;
}

template <typename mint, typename Mat>
mint fast_det(const Mat& A) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  fps D;
  while (true) {
    do {
      D = random_poly<mint>(n);
    } while (any_of(begin(D), end(D), [](mint x) { return x == mint(0); }));

    Mat AD = A;
    for (int i = 0; i < n; i++) AD.apply(i, D[i]);
    fps mp = mat_minpoly<mint, Mat>(AD);
    if (mp.back() == 0) return 0;
    if ((int)mp.size() != n + 1) continue;
    mint det = n & 1 ? -mp.back() : mp.back();
    mint Ddet = 1;
    for (auto& d : D) Ddet *= d;
    return det / Ddet;
  }
  exit(1);
}

template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_linear_equation(const Mat& A, const FormalPowerSeries<mint>& b) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  fps mp = mat_minpoly<mint, Mat>(A).rev();
  fps buf = b, res(n);
  for (int i = 1; i < (int)mp.size(); i++) {
    res = buf * mp[i];
    buf = A * buf;
  }
  return buf * mp[0].inverse();
}

}  // namespace BBLAImpl

using BBLAImpl::fast_det;
using BBLAImpl::fast_pow;
using BBLAImpl::ModMatrix;
using BBLAImpl::SparseMatrix;

/**
 * @brief Black Box Linear Algebra
 */
Back to top page