#include "ntt/multiplicative-convolution-mod-p.hpp"
#pragma once #include <cassert> #include <vector> using namespace std; #include "../math/constexpr-primitive-root.hpp" template <typename fps> fps multiplicative_convolution_mod_p(const fps& a, const fps& b, int p) { assert((int)a.size() == p); assert((int)b.size() == p); using mint = typename fps::value_type; int r = constexpr_primitive_root(p); vector<int> exp(p - 1), log(p); exp[0] = 1; for (int i = 1; i < p - 1; i++) exp[i] = 1LL * exp[i - 1] * r % p; for (int i = 0; i < p - 1; i++) log[exp[i]] = i; fps s(p - 1), t(p - 1); for (int i = 0; i < p - 1; i++) s[i] = a[exp[i]], t[i] = b[exp[i]]; fps u = s * t; for (int i = p - 1; i < (int)u.size(); i++) u[i % (p - 1)] += u[i]; fps c(p); for (int i = 1; i < p; i++) c[i] = u[log[i]]; mint sa = accumulate(begin(a), end(a), mint{}); mint sb = accumulate(begin(b), end(b), mint{}); c[0] = sa * b[0] + sb * a[0] - a[0] * b[0]; return c; }
#line 2 "ntt/multiplicative-convolution-mod-p.hpp"

#include <cassert>
#include <numeric>
#include <vector>
using namespace std;

#line 2 "math/constexpr-primitive-root.hpp"

// Smallest primitive root modulo a prime `mod`, usable in constant expressions.
//
// Factors mod - 1 by trial division. Then returns the smallest candidate
// g >= 2 with g^((mod - 1) / q) != 1 (mod mod) for every distinct prime
// q | mod - 1. Returns 1 for mod == 2.
// NOTE(review): `mod` is not validated. It is assumed prime; for other values
// the result need not be a generator, and mod <= 1 is not handled.
constexpr unsigned int constexpr_primitive_root(unsigned int mod) {
  using u32 = unsigned int;
  using u64 = unsigned long long;
  if (mod == 2) return 1;
  // Distinct prime factors of mod - 1 (a 32-bit value has fewer than 32).
  u64 m = mod - 1, ds[32] = {}, idx = 0;
  for (u64 i = 2; i * i <= m; ++i) {
    if (m % i == 0) {
      ds[idx++] = i;
      while (m % i == 0) m /= i;
    }
  }
  if (m != 1) ds[idx++] = m;
  for (u32 _pr = 2, flg = true;; _pr++, flg = true) {
    for (u32 i = 0; i < idx && flg; ++i) {
      // r = _pr^((mod - 1) / ds[i]) mod `mod`. a and r stay below 2^32, so
      // the u64 products cannot overflow.
      u64 a = _pr, b = (mod - 1) / ds[i], r = 1;
      for (; b; a = a * a % mod, b >>= 1)
        if (b & 1) r = r * a % mod;
      if (r == 1) flg = false;
    }
    if (flg == true) return _pr;
  }
}
#line 8 "ntt/multiplicative-convolution-mod-p.hpp"

// Multiplicative convolution modulo a prime p.
//
// Given a[0..p-1] and b[0..p-1], returns c[0..p-1] with
//   c[k] = sum of a[i] * b[j] over all (i, j) with i * j ≡ k (mod p).
//
// Nonzero residues: writing x = r^e for a primitive root r turns
// multiplication mod p into addition of exponents mod p-1. The nonzero part
// is therefore a cyclic convolution of length p-1. It is computed as one
// linear product (fps::operator*) followed by folding indices mod p-1.
//
// Zero residue: i * j ≡ 0 iff i == 0 or j == 0. Inclusion-exclusion gives
//   c[0] = a[0] * sum(b) + b[0] * sum(a) - a[0] * b[0].
//
// Parameters:
//   a, b : length-p coefficient sequences (checked by assert).
//   p    : modulus. Assumed prime (not checked); must be >= 2.
// Requirements on fps (assumed, not checked here): fps(n) builds n zero
// coefficients; fps::value_type supports + - * and +=; operator* returns the
// polynomial product. Trailing zeros of that product may be trimmed, because
// the fold below tolerates any result length.
template <typename fps>
fps multiplicative_convolution_mod_p(const fps& a, const fps& b, int p) {
  assert(p >= 2);
  assert((int)a.size() == p);
  assert((int)b.size() == p);
  using mint = typename fps::value_type;
  const int n = p - 1;  // order of the multiplicative group mod p
  const int r = static_cast<int>(constexpr_primitive_root(p));

  // pw[e] = r^e mod p, and lg is its inverse (discrete log) on 1..p-1.
  vector<int> pw(n), lg(p);
  pw[0] = 1;
  for (int i = 1; i < n; i++) pw[i] = 1LL * pw[i - 1] * r % p;
  for (int i = 0; i < n; i++) lg[pw[i]] = i;

  // Re-index the nonzero residues by exponent.
  fps s(n), t(n);
  for (int i = 0; i < n; i++) s[i] = a[pw[i]], t[i] = b[pw[i]];

  // Cyclic convolution of length n: fold the linear product into a fresh
  // buffer, so a product shorter than n (e.g. trimmed zeros) stays in range.
  fps u = s * t;
  fps w(n);
  for (int i = 0; i < (int)u.size(); i++) w[i % n] += u[i];

  fps c(p);
  for (int i = 1; i < p; i++) c[i] = w[lg[i]];

  // Pairs with i == 0 or j == 0, counted via inclusion-exclusion.
  mint sa = accumulate(begin(a), end(a), mint{});
  mint sb = accumulate(begin(b), end(b), mint{});
  c[0] = sa * b[0] + sb * a[0] - a[0] * b[0];
  return c;
}