Black Box Linear Algebra
(matrix/black-box-linear-algebra.hpp)
Depends on: fps/berlekamp-massey.hpp, fps/formal-power-series.hpp, fps/mod-pow.hpp, misc/rng.hpp
Verified with
Code
#include "../fps/berlekamp-massey.hpp"
#include "../fps/formal-power-series.hpp"
#include "../fps/mod-pow.hpp"
#include "../misc/rng.hpp"
//
namespace BBLAImpl {
// Returns sum_i a[i] * b[i]. a and b must have the same length.
template <typename mint>
mint inner_product(const FormalPowerSeries<mint>& a,
                   const FormalPowerSeries<mint>& b) {
  mint res = 0;
  int n = a.size();
  assert(n == (int)b.size());
  for (int i = 0; i < n; i++) res += a[i] * b[i];
  return res;
}

// Returns a length-n vector whose entries are randint(0, mod), i.e. values
// drawn from [0, mod).
template <typename mint>
FormalPowerSeries<mint> random_poly(int n) {
  FormalPowerSeries<mint> res(n);
  for (auto& x : res) x = randint(0, mint::get_mod());
  return res;
}

// Dense n x n matrix stored as n rows of length n. Exposes the black-box
// interface used by the algorithms below: size(), A * v and apply(i, r).
template <typename mint>
struct ModMatrix : vector<FormalPowerSeries<mint>> {
  using fps = FormalPowerSeries<mint>;
  // n rows of n value-initialized entries (zero for the usual modint types).
  ModMatrix(int n) : vector<fps>(n, fps(n)) {}
  // A[i][j] += x.
  inline void add(int i, int j, mint x) { (*this)[i][j] += x; }
  // Matrix-vector product A * r in O(n^2).
  friend fps operator*(const ModMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (int j = 0; j < n; j++) res[i] += m[i][j] * r[j];
    return res;
  }
  // Multiplies row i by r.
  void apply(int i, mint r) {
    int n = (*this).size();
    for (int j = 0; j < n; j++) (*this)[i][j] *= r;
  }
};

// Sparse matrix: row i is a list of (column, value) pairs. Several entries
// added for the same (i, j) are summed by operator*.
template <typename mint>
struct SparseMatrix : vector<vector<pair<int, mint>>> {
  using fps = FormalPowerSeries<mint>;
  // Forwards all arguments to the underlying vector constructor, e.g.
  // SparseMatrix<mint>(n) creates n empty rows.
  template <typename... Args>
  SparseMatrix(Args... args) : vector<vector<pair<int, mint>>>(args...) {}
  // Records the entry A[i][j] += x.
  inline void add(int i, int j, mint x) { (*this)[i].emplace_back(j, x); }
  // Matrix-vector product A * r in O(n + number of stored entries).
  friend fps operator*(const SparseMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (auto&& [j, x] : m[i]) res[i] += x * r[j];
    return res;
  }
  // Multiplies row i by r.
  void apply(int i, mint r) {
    for (auto&& [_, x] : (*this)[i]) x *= r;
  }
};

// Projects the vector sequence b[0], b[1], ... onto a random vector u
// (a[i] = <b[i], u>) and returns BerlekampMassey(a): a polynomial C with
// C[0] == 1 and sum_j C[j] * a[i - j] == 0. With high probability over u this
// is also the shortest recurrence of the vector sequence itself.
template <typename mint>
FormalPowerSeries<mint> vector_minpoly(
    const vector<FormalPowerSeries<mint>>& b) {
  assert(!b.empty());
  int n = b.size(), m = b[0].size();
  FormalPowerSeries<mint> u = random_poly<mint>(m), a(n);
  for (int i = 0; i < n; i++) a[i] = inner_product(b[i], u);
  auto mp = BerlekampMassey<mint>(a);
  return {mp.begin(), mp.end()};
}

// Randomized (Wiedemann-style) minimal polynomial of the n x n black box A.
// The result is in reversed form: if
//   P(t) = t^L + p_{L-1} t^{L-1} + ... + p_0,
// it returns P.rev() = {1, p_{L-1}, ..., p_0}. It divides the true minimal
// polynomial of A and equals it with high probability.
template <typename mint, typename Mat>
FormalPowerSeries<mint> mat_minpoly(const Mat& A) {
  int n = A.size();
  FormalPowerSeries<mint> u = random_poly<mint>(n);
  // 2n + 1 Krylov vectors u, Au, A^2 u, ... let Berlekamp-Massey recover any
  // recurrence of order <= n.
  const int len = n * 2 + 1;
  vector<FormalPowerSeries<mint>> b(len);
  for (int i = 0; i < len; i++) {
    b[i] = u;
    if (i + 1 < len) u = A * u;  // the product after the last term is unused
  }
  return vector_minpoly(b);
}

// Computes A^k b. With P the (randomized) minimal polynomial of A, reduces
// t^k modulo P to get c with t^k == sum_i c[i] t^i (mod P), and returns
// sum_i c[i] A^i b. Correct with high probability.
template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_pow(const Mat& A, FormalPowerSeries<mint> b,
                                 int64_t k) {
  using fps = FormalPowerSeries<mint>;
  int n = b.size();
  fps mp = mat_minpoly<mint, Mat>(A);
  // mp is reversed; mp.rev() is the monic minimal polynomial P.
  fps c = mod_pow<mint>(k, fps{0, 1}, mp.rev());
  fps res(n);
  for (int i = 0; i < (int)c.size(); i++) res += b * c[i], b = A * b;
  return res;
}

// Determinant of A by randomized diagonal preconditioning.
// AD = diag(D) * A (row i scaled by D[i]) for random nonzero D. When the
// computed minimal polynomial P of AD has degree n it is the characteristic
// polynomial, so det(AD) = (-1)^n P(0) and det(A) = det(AD) / prod(D).
// If P(0) == 0 then t divides the minimal polynomial, so A is singular and 0
// is returned. Otherwise, if deg P < n, it retries with fresh randomness.
// NOTE: if n >= mod, D always has a repeated entry; e.g. for A = I the degree
// then stays below n and the loop never terminates. Intended for mod >> n.
template <typename mint, typename Mat>
mint fast_det(const Mat& A) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  fps D;
  while (true) {
    do {
      D = random_poly<mint>(n);
    } while (any_of(begin(D), end(D), [](mint x) { return x == mint(0); }));
    Mat AD = A;
    for (int i = 0; i < n; i++) AD.apply(i, D[i]);
    fps mp = mat_minpoly<mint, Mat>(AD);
    // mp is reversed, so mp.back() is the constant term P(0).
    if (mp.back() == 0) return 0;
    if ((int)mp.size() != n + 1) continue;
    mint det = n & 1 ? -mp.back() : mp.back();
    mint Ddet = 1;
    for (auto& d : D) Ddet *= d;
    return det / Ddet;
  }
  exit(1);  // unreachable
}

// Solves A x = b for a nonsingular n x n black box A (Wiedemann's method).
// If P(t) = p_0 + p_1 t + ... + p_L t^L annihilates A with p_0 != 0, then
// P(A) b = 0 gives
//   x = A^{-1} b = -(p_1 b + p_2 A b + ... + p_L A^{L-1} b) / p_0.
// Correct with high probability (the minimal polynomial is randomized).
template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_linear_equation(const Mat& A, const FormalPowerSeries<mint>& b) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  assert(n == (int)b.size());
  // mat_minpoly returns the reversed polynomial; rev() puts p_0 first.
  fps mp = mat_minpoly<mint, Mat>(A).rev();
  assert(mp[0] != mint(0) && "fast_linear_equation: A must be nonsingular");
  const int len = mp.size();
  fps buf = b, res(n);  // buf runs through b, A b, A^2 b, ...
  for (int i = 1; i < len; i++) {
    res += buf * mp[i];
    if (i + 1 < len) buf = A * buf;
  }
  res *= -mp[0].inverse();
  return res;
}
}  // namespace BBLAImpl
using BBLAImpl::fast_det;
using BBLAImpl::fast_pow;
using BBLAImpl::ModMatrix;
using BBLAImpl::SparseMatrix;
/**
* @brief Black Box Linear Algebra
*/
#line 1 "matrix/black-box-linear-algebra.hpp"
#line 2 "fps/berlekamp-massey.hpp"
// Berlekamp-Massey: finds the shortest linear recurrence satisfied by s.
// Returns c with c[0] == 1 such that, for every index n with
// c.size() - 1 <= n < s.size(),
//   sum_{j=0}^{c.size()-1} c[j] * s[n - j] == 0.
// mint must be a field (the discrepancy ratio x / y is a division). O(N^2).
template <typename mint>
vector<mint> BerlekampMassey(const vector<mint> &s) {
const int N = (int)s.size();
// c: current connection polynomial; b: the connection polynomial before the
// last length change. Both keep the constant term at the BACK until the
// final reverse, so c[l - 1 - j] is the coefficient of x^j.
vector<mint> b, c;
b.reserve(N + 1);
c.reserve(N + 1);
b.push_back(mint(1));
c.push_back(mint(1));
// y: discrepancy recorded at the last length change.
mint y = mint(1);
for (int ed = 1; ed <= N; ed++) {
int l = int(c.size()), m = int(b.size());
// x: discrepancy of the current recurrence when predicting s[ed - 1].
mint x = 0;
for (int i = 0; i < l; i++) x += c[i] * s[ed - l + i];
// Appending 0 at the back multiplies b by x (one more step of shift).
b.emplace_back(mint(0));
m++;
if (x == mint(0)) continue;
mint freq = x / y;
if (l < m) {
// The recurrence must get longer: left-pad c to length m, subtract
// freq * b, and keep the previous c as the new b.
auto tmp = c;
c.insert(begin(c), m - l, mint(0));
for (int i = 0; i < m; i++) c[m - 1 - i] -= freq * b[m - 1 - i];
b = tmp;
y = x;
} else {
// Length unchanged: subtract freq * b aligned at the back.
for (int i = 0; i < m; i++) c[l - 1 - i] -= freq * b[m - 1 - i];
}
}
// Move the constant term to the front.
reverse(begin(c), end(c));
return c;
}
#line 2 "fps/formal-power-series.hpp"
// Polynomial / formal power series over mint, stored as a coefficient vector:
// element i is the coefficient of x^i.
// The NTT-related members declared at the bottom (operator*=(const FPS &),
// ntt, intt, ntt_doubling, ntt_pr, set_fft, inv, exp) are only declared here;
// their definitions are not in this file.
template <typename mint>
struct FormalPowerSeries : vector<mint> {
using vector<mint>::vector;
using FPS = FormalPowerSeries;
// Coefficient-wise addition; grows *this to r.size() if r is longer.
FPS &operator+=(const FPS &r) {
if (r.size() > this->size()) this->resize(r.size());
for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
return *this;
}
// Adds r to the constant term (creating it if *this is empty).
FPS &operator+=(const mint &r) {
if (this->empty()) this->resize(1);
(*this)[0] += r;
return *this;
}
// Coefficient-wise subtraction; grows *this to r.size() if r is longer.
FPS &operator-=(const FPS &r) {
if (r.size() > this->size()) this->resize(r.size());
for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
return *this;
}
// Subtracts r from the constant term (creating it if *this is empty).
FPS &operator-=(const mint &r) {
if (this->empty()) this->resize(1);
(*this)[0] -= r;
return *this;
}
// Multiplies every coefficient by v.
FPS &operator*=(const mint &v) {
for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
return *this;
}
// Polynomial quotient floor(*this / r). The result has
// this->size() - r.size() + 1 coefficients, or is empty when *this is
// shorter than r.
// Divisors with r.size() <= 64 use schoolbook long division; larger ones use
// the reversal trick with the power-series inverse of r.rev().
// NOTE(review): the large path needs r.back() != 0 (r.rev().inv(n)). The
// small path shrinks r, but still truncates the quotient to the size computed
// from the unshrunk r, so an r with trailing zeros gives a truncated
// quotient. Confirm that callers always pass a shrunk r.
FPS &operator/=(const FPS &r) {
if (this->size() < r.size()) {
this->clear();
return *this;
}
int n = this->size() - r.size() + 1;
if ((int)r.size() <= 64) {
FPS f(*this), g(r);
g.shrink();
// Make g monic; the quotient is rescaled by the same factor at the end.
mint coeff = g.back().inverse();
for (auto &x : g) x *= coeff;
int deg = (int)f.size() - (int)g.size() + 1;
int gs = g.size();
FPS quo(deg);
// Long division, from the highest quotient coefficient down.
for (int i = deg - 1; i >= 0; i--) {
quo[i] = f[i + gs - 1];
for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
}
*this = quo * coeff;
this->resize(n, mint(0));
return *this;
}
return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
}
// Remainder *this - (*this / r) * r, with trailing zeros removed.
FPS &operator%=(const FPS &r) {
*this -= *this / r * r;
shrink();
return *this;
}
// Non-mutating counterparts of the compound operators above.
FPS operator+(const FPS &r) const { return FPS(*this) += r; }
FPS operator+(const mint &v) const { return FPS(*this) += v; }
FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
FPS operator-(const mint &v) const { return FPS(*this) -= v; }
FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
FPS operator*(const mint &v) const { return FPS(*this) *= v; }
FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
FPS operator%(const FPS &r) const { return FPS(*this) %= r; }
// Coefficient-wise negation.
FPS operator-() const {
FPS ret(this->size());
for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
return ret;
}
// Removes trailing zero coefficients (may leave the vector empty).
void shrink() {
while (this->size() && this->back() == mint(0)) this->pop_back();
}
// Coefficients in reverse order (same size).
FPS rev() const {
FPS ret(*this);
reverse(begin(ret), end(ret));
return ret;
}
// Coefficient-wise product, truncated to the shorter of the two lengths.
FPS dot(FPS r) const {
FPS ret(min(this->size(), r.size()));
for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
return ret;
}
// Takes the first sz terms; pads with zeros if there are fewer than sz.
FPS pre(int sz) const {
FPS ret(begin(*this), begin(*this) + min((int)this->size(), sz));
if ((int)ret.size() < sz) ret.resize(sz);
return ret;
}
// Drops the first sz coefficients (divide by x^sz, discard the remainder);
// returns an empty series if size() <= sz.
FPS operator>>(int sz) const {
if ((int)this->size() <= sz) return {};
FPS ret(*this);
ret.erase(ret.begin(), ret.begin() + sz);
return ret;
}
// Prepends sz zero coefficients (multiply by x^sz).
FPS operator<<(int sz) const {
FPS ret(*this);
ret.insert(ret.begin(), sz, mint(0));
return ret;
}
// Formal derivative; the result has max(0, size() - 1) coefficients.
FPS diff() const {
const int n = (int)this->size();
FPS ret(max(0, n - 1));
mint one(1), coeff(1);
for (int i = 1; i < n; i++) {
ret[i - 1] = (*this)[i] * coeff;
coeff += one;
}
return ret;
}
// Formal integral with zero constant term; the result has size() + 1
// coefficients. ret[i] first holds 1/i via the recurrence
// inv[i] = -inv[mod % i] * (mod / i), which assumes mod is prime and
// size() < mod.
FPS integral() const {
const int n = (int)this->size();
FPS ret(n + 1);
ret[0] = mint(0);
if (n > 0) ret[1] = mint(1);
auto mod = mint::get_mod();
for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
return ret;
}
// Evaluates the polynomial at x, accumulating successive powers of x in w.
mint eval(mint x) const {
mint r = 0, w = 1;
for (auto &v : *this) r += w * v, w *= x;
return r;
}
// log(f) mod x^deg for f[0] == 1, computed as integral(f' * f^{-1});
// deg defaults to size().
FPS log(int deg = -1) const {
assert(!(*this).empty() && (*this)[0] == mint(1));
if (deg == -1) deg = (int)this->size();
return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
}
// f^k mod x^deg (deg defaults to size()). Factors out the lowest nonzero term
// a_i x^i, computes (f / (a_i x^i))^k as exp(k * log(.)), then multiplies by
// a_i^k x^(i*k). Returns deg zeros when that shift reaches deg.
FPS pow(int64_t k, int deg = -1) const {
const int n = (int)this->size();
if (deg == -1) deg = n;
if (k == 0) {
FPS ret(deg);
if (deg) ret[0] = 1;
return ret;
}
for (int i = 0; i < n; i++) {
if ((*this)[i] != mint(0)) {
mint rev = mint(1) / (*this)[i];
FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
ret *= (*this)[i].pow(k);
ret = (ret << (i * k)).pre(deg);
if ((int)ret.size() < deg) ret.resize(deg, mint(0));
return ret;
}
// The lowest nonzero index is >= i + 1, so the shift is >= (i + 1) * k.
if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
}
return FPS(deg, mint(0));
}
// Declared here, defined elsewhere (NTT-based multiplication, inverse, exp).
static void *ntt_ptr;
static void set_fft();
FPS &operator*=(const FPS &r);
void ntt();
void intt();
void ntt_doubling();
static int ntt_pr();
FPS inv(int deg = -1) const;
FPS exp(int deg = -1) const;
};
// Out-of-line definition of the static member ntt_ptr.
template <typename mint>
void *FormalPowerSeries<mint>::ntt_ptr = nullptr;
/**
* @brief 多項式/形式的冪級数ライブラリ
* @docs docs/fps/formal-power-series.md
*/
#line 2 "fps/mod-pow.hpp"
#line 4 "fps/mod-pow.hpp"
// Computes base(x)^k mod d(x) by binary exponentiation.
//   k    : exponent, must be >= 0.
//   base : any polynomial. It is reduced modulo d first, so its degree may
//          exceed deg(d).
//   d    : non-empty modulus with nonzero leading coefficient d.back().
//          Quotients are computed with the power-series inverse of d.rev().
// For k >= 1 the result has size < d.size(), with trailing zeros removed.
// For k == 0 the constant {1} is returned as-is.
template <typename mint>
FormalPowerSeries<mint> mod_pow(int64_t k, const FormalPowerSeries<mint>& base,
                                const FormalPowerSeries<mint>& d) {
  using fps = FormalPowerSeries<mint>;
  assert(!d.empty());
  assert(d.back() != mint(0));
  assert(k >= 0);  // a negative k would never reach 0 under k >>= 1
  // floor(poly / d) needs poly.size() - d.size() + 1 terms of 1 / rev(d).
  // Products of reduced polynomials need at most d.size() - 2 terms. The
  // initial reduction of base needs base.size() - d.size() + 1 terms.
  const int len = max((int)d.size(), (int)base.size());
  auto inv = d.rev().inv(len);
  auto quo = [&](const fps& poly) {
    if (poly.size() < d.size()) return fps{};
    int n = poly.size() - d.size() + 1;
    return (poly.rev().pre(n) * inv.pre(n)).pre(n).rev();
  };
  fps res{1}, b(base);
  // Reduce base up front so that squaring never exceeds the inverse's length.
  if (b.size() >= d.size()) {
    b -= quo(b) * d;
    b.shrink();
  }
  while (k) {
    if (k & 1) {
      res *= b;
      res -= quo(res) * d;
      res.shrink();
    }
    b *= b;
    b -= quo(b) * d;
    b.shrink();
    k >>= 1;
    assert(b.size() + 1 <= d.size());
    assert(res.size() + 1 <= d.size());
  }
  return res;
}
/**
* @brief Mod-Pow ($f(x)^k \mod g(x)$)
*/
#line 2 "misc/rng.hpp"
#line 2 "internal/internal-seed.hpp"
#include <chrono>
using namespace std;
namespace internal {
// All functions are `inline` because they are defined in a header; without
// it, including the header in two translation units is an ODR violation
// (multiple-definition link error).

// Seed derived from the current high_resolution_clock time, scrambled with a
// constant XOR and three shift-XOR steps.
inline unsigned long long non_deterministic_seed() {
  unsigned long long m =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::high_resolution_clock::now().time_since_epoch())
          .count();
  m ^= 9845834732710364265uLL;
  m ^= m << 24, m ^= m >> 31, m ^= m << 35;
  return m;
}

// Fixed seed used for reproducible local runs.
inline unsigned long long deterministic_seed() { return 88172645463325252UL; }

// Returns a 64-bit seed. The seed is fixed locally, i.e. when NyaanLocal is
// defined and RANDOMIZED_SEED is not. Otherwise it is time-based.
// Beware: consecutive calls may return the same value many times.
// Define RANDOMIZED_SEED to make the seed random even locally.
inline unsigned long long seed() {
#if defined(NyaanLocal) && !defined(RANDOMIZED_SEED)
  return deterministic_seed();
#else
  return non_deterministic_seed();
#endif
}
}  // namespace internal
#line 4 "misc/rng.hpp"
namespace my_rand {
using i64 = long long;
using u64 = unsigned long long;
// Non-template functions are `inline` because they are defined in a header
// (otherwise: multiple-definition link errors across translation units).

// Shift-XOR generator (shifts 7 and 9) whose state is seeded once from
// internal::seed(). Both steps are invertible linear maps, so a nonzero state
// never becomes zero: outputs lie in [1, 2^64 - 1] unless the seed is 0.
inline u64 rng() {
  static u64 _x = internal::seed();
  return _x ^= _x << 7, _x ^= _x >> 9;
}
// Integer in the closed range [l, r], by modulo reduction (slightly biased
// when the width does not divide 2^64).
inline i64 rng(i64 l, i64 r) {
  assert(l <= r);
  // The width is computed in unsigned arithmetic, so r - l cannot overflow.
  // width == 0 means the full 2^64-value range.
  const u64 width = u64(r) - u64(l) + 1;
  if (width == 0) return i64(rng());
  return i64(u64(l) + rng() % width);
}
// Integer in the half-open range [l, r), by modulo reduction.
inline i64 randint(i64 l, i64 r) {
  assert(l < r);
  return i64(u64(l) + rng() % (u64(r) - u64(l)));
}
// Chooses n distinct integers from [l, r) with Floyd's sampling algorithm and
// returns them in ascending order.
inline vector<i64> randset(i64 l, i64 r, i64 n) {
  assert(l <= r && 0 <= n && n <= r - l);
  unordered_set<i64> s;
  for (i64 i = n; i; --i) {
    i64 m = randint(l, r + 1 - i);
    if (s.find(m) != s.end()) m = r - i;
    s.insert(m);
  }
  vector<i64> ret;
  for (auto& x : s) ret.push_back(x);
  sort(begin(ret), end(ret));
  return ret;
}
// Double in [0.0, 1.0): the top 53 bits of rng() scaled by 2^-53. Every value
// is exactly representable, so the result can never round up to 1.0. The old
// rng() * 2^-64 rounded outputs near 2^64 to exactly 1.0.
inline double rnd() { return double(rng() >> 11) * (1.0 / 9007199254740992.0); }
// Double in [l, r), up to floating-point rounding of l + u * (r - l).
inline double rnd(double l, double r) {
  assert(l < r);
  return l + rnd() * (r - l);
}
// Uniform in-place shuffle (Fisher-Yates).
template <typename T>
void randshf(vector<T>& v) {
  int n = v.size();
  for (int i = 1; i < n; i++) swap(v[i], v[randint(0, i + 1)]);
}
}  // namespace my_rand
using my_rand::randint;
using my_rand::randset;
using my_rand::randshf;
using my_rand::rnd;
using my_rand::rng;
#line 6 "matrix/black-box-linear-algebra.hpp"
//
namespace BBLAImpl {
// Returns sum_i a[i] * b[i]. a and b must have the same length.
template <typename mint>
mint inner_product(const FormalPowerSeries<mint>& a,
                   const FormalPowerSeries<mint>& b) {
  mint res = 0;
  int n = a.size();
  assert(n == (int)b.size());
  for (int i = 0; i < n; i++) res += a[i] * b[i];
  return res;
}

// Returns a length-n vector whose entries are randint(0, mod), i.e. values
// drawn from [0, mod).
template <typename mint>
FormalPowerSeries<mint> random_poly(int n) {
  FormalPowerSeries<mint> res(n);
  for (auto& x : res) x = randint(0, mint::get_mod());
  return res;
}

// Dense n x n matrix stored as n rows of length n. Exposes the black-box
// interface used by the algorithms below: size(), A * v and apply(i, r).
template <typename mint>
struct ModMatrix : vector<FormalPowerSeries<mint>> {
  using fps = FormalPowerSeries<mint>;
  // n rows of n value-initialized entries (zero for the usual modint types).
  ModMatrix(int n) : vector<fps>(n, fps(n)) {}
  // A[i][j] += x.
  inline void add(int i, int j, mint x) { (*this)[i][j] += x; }
  // Matrix-vector product A * r in O(n^2).
  friend fps operator*(const ModMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (int j = 0; j < n; j++) res[i] += m[i][j] * r[j];
    return res;
  }
  // Multiplies row i by r.
  void apply(int i, mint r) {
    int n = (*this).size();
    for (int j = 0; j < n; j++) (*this)[i][j] *= r;
  }
};

// Sparse matrix: row i is a list of (column, value) pairs. Several entries
// added for the same (i, j) are summed by operator*.
template <typename mint>
struct SparseMatrix : vector<vector<pair<int, mint>>> {
  using fps = FormalPowerSeries<mint>;
  // Forwards all arguments to the underlying vector constructor, e.g.
  // SparseMatrix<mint>(n) creates n empty rows.
  template <typename... Args>
  SparseMatrix(Args... args) : vector<vector<pair<int, mint>>>(args...) {}
  // Records the entry A[i][j] += x.
  inline void add(int i, int j, mint x) { (*this)[i].emplace_back(j, x); }
  // Matrix-vector product A * r in O(n + number of stored entries).
  friend fps operator*(const SparseMatrix& m, const fps& r) {
    int n = m.size();
    assert(n == (int)r.size());
    fps res(n);
    for (int i = 0; i < n; i++)
      for (auto&& [j, x] : m[i]) res[i] += x * r[j];
    return res;
  }
  // Multiplies row i by r.
  void apply(int i, mint r) {
    for (auto&& [_, x] : (*this)[i]) x *= r;
  }
};

// Projects the vector sequence b[0], b[1], ... onto a random vector u
// (a[i] = <b[i], u>) and returns BerlekampMassey(a): a polynomial C with
// C[0] == 1 and sum_j C[j] * a[i - j] == 0. With high probability over u this
// is also the shortest recurrence of the vector sequence itself.
template <typename mint>
FormalPowerSeries<mint> vector_minpoly(
    const vector<FormalPowerSeries<mint>>& b) {
  assert(!b.empty());
  int n = b.size(), m = b[0].size();
  FormalPowerSeries<mint> u = random_poly<mint>(m), a(n);
  for (int i = 0; i < n; i++) a[i] = inner_product(b[i], u);
  auto mp = BerlekampMassey<mint>(a);
  return {mp.begin(), mp.end()};
}

// Randomized (Wiedemann-style) minimal polynomial of the n x n black box A.
// The result is in reversed form: if
//   P(t) = t^L + p_{L-1} t^{L-1} + ... + p_0,
// it returns P.rev() = {1, p_{L-1}, ..., p_0}. It divides the true minimal
// polynomial of A and equals it with high probability.
template <typename mint, typename Mat>
FormalPowerSeries<mint> mat_minpoly(const Mat& A) {
  int n = A.size();
  FormalPowerSeries<mint> u = random_poly<mint>(n);
  // 2n + 1 Krylov vectors u, Au, A^2 u, ... let Berlekamp-Massey recover any
  // recurrence of order <= n.
  const int len = n * 2 + 1;
  vector<FormalPowerSeries<mint>> b(len);
  for (int i = 0; i < len; i++) {
    b[i] = u;
    if (i + 1 < len) u = A * u;  // the product after the last term is unused
  }
  return vector_minpoly(b);
}

// Computes A^k b. With P the (randomized) minimal polynomial of A, reduces
// t^k modulo P to get c with t^k == sum_i c[i] t^i (mod P), and returns
// sum_i c[i] A^i b. Correct with high probability.
template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_pow(const Mat& A, FormalPowerSeries<mint> b,
                                 int64_t k) {
  using fps = FormalPowerSeries<mint>;
  int n = b.size();
  fps mp = mat_minpoly<mint, Mat>(A);
  // mp is reversed; mp.rev() is the monic minimal polynomial P.
  fps c = mod_pow<mint>(k, fps{0, 1}, mp.rev());
  fps res(n);
  for (int i = 0; i < (int)c.size(); i++) res += b * c[i], b = A * b;
  return res;
}

// Determinant of A by randomized diagonal preconditioning.
// AD = diag(D) * A (row i scaled by D[i]) for random nonzero D. When the
// computed minimal polynomial P of AD has degree n it is the characteristic
// polynomial, so det(AD) = (-1)^n P(0) and det(A) = det(AD) / prod(D).
// If P(0) == 0 then t divides the minimal polynomial, so A is singular and 0
// is returned. Otherwise, if deg P < n, it retries with fresh randomness.
// NOTE: if n >= mod, D always has a repeated entry; e.g. for A = I the degree
// then stays below n and the loop never terminates. Intended for mod >> n.
template <typename mint, typename Mat>
mint fast_det(const Mat& A) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  fps D;
  while (true) {
    do {
      D = random_poly<mint>(n);
    } while (any_of(begin(D), end(D), [](mint x) { return x == mint(0); }));
    Mat AD = A;
    for (int i = 0; i < n; i++) AD.apply(i, D[i]);
    fps mp = mat_minpoly<mint, Mat>(AD);
    // mp is reversed, so mp.back() is the constant term P(0).
    if (mp.back() == 0) return 0;
    if ((int)mp.size() != n + 1) continue;
    mint det = n & 1 ? -mp.back() : mp.back();
    mint Ddet = 1;
    for (auto& d : D) Ddet *= d;
    return det / Ddet;
  }
  exit(1);  // unreachable
}

// Solves A x = b for a nonsingular n x n black box A (Wiedemann's method).
// If P(t) = p_0 + p_1 t + ... + p_L t^L annihilates A with p_0 != 0, then
// P(A) b = 0 gives
//   x = A^{-1} b = -(p_1 b + p_2 A b + ... + p_L A^{L-1} b) / p_0.
// Correct with high probability (the minimal polynomial is randomized).
template <typename mint, typename Mat>
FormalPowerSeries<mint> fast_linear_equation(const Mat& A, const FormalPowerSeries<mint>& b) {
  using fps = FormalPowerSeries<mint>;
  int n = A.size();
  assert(n == (int)b.size());
  // mat_minpoly returns the reversed polynomial; rev() puts p_0 first.
  fps mp = mat_minpoly<mint, Mat>(A).rev();
  assert(mp[0] != mint(0) && "fast_linear_equation: A must be nonsingular");
  const int len = mp.size();
  fps buf = b, res(n);  // buf runs through b, A b, A^2 b, ...
  for (int i = 1; i < len; i++) {
    res += buf * mp[i];
    if (i + 1 < len) buf = A * buf;
  }
  res *= -mp[0].inverse();
  return res;
}
}  // namespace BBLAImpl
using BBLAImpl::fast_det;
using BBLAImpl::fast_pow;
using BBLAImpl::ModMatrix;
using BBLAImpl::SparseMatrix;
/**
* @brief Black Box Linear Algebra
*/
Back to top page