#include "math-fast/subset-convolution.hpp"
#pragma once #include <cassert> #include <vector> using namespace std; #include "../modint/vectorize-modint.hpp" template <typename mint> vector<mint> fast_multiply(const vector<mint>& a, const vector<mint>& b) { int n = a.size(); int d = __builtin_ctz(n); assert(d <= 23); mmint* a1 = new mmint[max(n, 8) * 3]; mmint* a2 = new mmint[max(n, 8) * 3]; memset((void*)a1, 0, max(n, 8) * 3 * sizeof(mmint)); memset((void*)a2, 0, max(n, 8) * 3 * sizeof(mmint)); mmint b1[24], b2[24], b3[24]; for (int i = 0; i < n; i++) { unsigned int pc = __builtin_popcount(i); a1[i * 3 + pc / 8][pc % 8] = a[i].a; a2[i * 3 + pc / 8][pc % 8] = b[i].a; } for (int j = 2; j <= n; j += 2) { unsigned int pc = __builtin_popcount(j); unsigned int ctz = __builtin_ctz(j); for (int h = 0; h < d; h++) { if (j & (1 << h)) break; int li = j - 2 * (1 << h), ri = j - (1 << h); if (pc + ctz <= 16) { for (int i = 0; i < 3 * (1 << h); i += 3) { a1[ri * 3 + i + 0] += a1[li * 3 + i + 0]; a2[ri * 3 + i + 0] += a2[li * 3 + i + 0]; a1[ri * 3 + i + 1] += a1[li * 3 + i + 1]; a2[ri * 3 + i + 1] += a2[li * 3 + i + 1]; } } else { for (int i = 0; i < 3 * (1 << h); i++) { a1[ri * 3 + i] += a1[li * 3 + i]; a2[ri * 3 + i] += a2[li * 3 + i]; } } } } mmint th = _mm256_set1_epi64x(4LL * mmint::M1[0] * mmint::M1[0]); for (int is = 0; is < n; is += 8) { int mpc = d; for (int i = is; i < is + 8; i++) { int pc = __builtin_popcount(i); mpc = min(mpc, pc); for (int j = 0; j <= d; j++) { b1[j][i - is] = a1[i * 3 + j / 8][j % 8]; b2[j][i - is] = a2[i * 3 + j / 8][j % 8]; b3[j][i - is] = 0; } } for (int j = 0; j <= d; j++) { m256 cmpB1 = _mm256_cmpgt_epi32(mmint::M1, b1[j]); m256 cmpB2 = _mm256_cmpgt_epi32(mmint::M1, b2[j]); m256 difB1 = _mm256_andnot_si256(cmpB1, mmint::M1); m256 difB2 = _mm256_andnot_si256(cmpB2, mmint::M1); b1[j] = _mm256_sub_epi32(b1[j], difB1); b2[j] = _mm256_sub_epi32(b2[j], difB2); } #define PROD(k) \ m256 A13##k = _mm256_shuffle_epi32(b1[j + k], 0xF5); \ m256 B13##k = _mm256_shuffle_epi32(b2[l - j - k], 0xF5); \ 
m256 p02##k = _mm256_mul_epi32(b1[j + k], b2[l - j - k]); \ m256 p13##k = _mm256_mul_epi32(A13##k, B13##k); \ prod02 = _mm256_add_epi64(prod02, p02##k); \ prod13 = _mm256_add_epi64(prod13, p13##k) #define COMP() \ do { \ m256 cmp02 = _mm256_cmpgt_epi64(prod02, th); \ m256 cmp13 = _mm256_cmpgt_epi64(prod13, th); \ m256 dif02 = _mm256_and_si256(cmp02, th); \ m256 dif13 = _mm256_and_si256(cmp13, th); \ prod02 = _mm256_sub_epi64(prod02, dif02); \ prod13 = _mm256_sub_epi64(prod13, dif13); \ } while (0) for (int l = mpc; l <= d; l++) { int j = 0; mmint prod02 = mmint::M0, prod13 = mmint::M0; for (; j <= l - 3; j += 4) { PROD(0); PROD(1); PROD(2); PROD(3); COMP(); } for (; j <= l; j++) { PROD(0); } COMP(); b3[l] = mmint::reduce(prod02, prod13); } #undef PROD #undef COMP for (int i = is; i < is + 8; i++) { for (unsigned j = mpc; j <= unsigned(d); j++) { a1[i * 3 + j / 8][j % 8] = b3[j][i - is]; } } } for (int j = 2; j <= n; j += 2) { for (int h = 0; h < d; h++) { if (j & (1 << h)) break; int li = j - 2 * (1 << h), ri = j - (1 << h); for (int i = 0; i < 3 * (1 << h); i++) { a1[ri * 3 + i] -= a1[li * 3 + i]; } } } vector<mint> c(n); for (int i = 0; i < n; i++) { unsigned int pc = __builtin_popcount(i); c[i].a = a1[i * 3 + pc / 8][pc % 8]; } delete[] (a1); delete[] (a2); return c; }
#line 2 "math-fast/subset-convolution.hpp"
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;
#line 2 "modint/vectorize-modint.hpp"
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")
#include <immintrin.h>
#include <cstdint>
#include <iostream>
using namespace std;

using m256 = __m256i;

// Eight 32-bit modular residues packed in one AVX2 register.
// The static members hold per-modulus constants broadcast to all lanes; they
// must be initialised (set_mod) before any arithmetic is performed.
struct alignas(32) mmint {
  m256 x;
  // R: mint::r, M0: zero, M1: modulus, M2: 2 * modulus, N2: mint::n2.
  static mmint R, M0, M1, M2, N2;

  // All lanes zero.
  mmint() : x() {}
  inline mmint(const m256& _x) : x(_x) {}
  // Broadcast a to all eight lanes.
  inline mmint(unsigned int a) : x(_mm256_set1_epi32(a)) {}
  // Lane i receives a_i.
  inline mmint(unsigned int a0, unsigned int a1, unsigned int a2,
               unsigned int a3, unsigned int a4, unsigned int a5,
               unsigned int a6, unsigned int a7)
      : x(_mm256_set_epi32(a7, a6, a5, a4, a3, a2, a1, a0)) {}

  inline operator m256&() { return x; }
  inline operator const m256&() const { return x; }

  // Direct access to lane i (0..7) as a signed 32-bit integer.
  inline int& operator[](int i) { return *(reinterpret_cast<int*>(&x) + i); }
  inline const int& operator[](int i) const {
    return *(reinterpret_cast<const int*>(&x) + i);
  }

  // Prints the eight lanes separated by single spaces, each passed through a
  // scalar reduction (b + (b * -r mod 2^32) * mod) >> 32 and brought below
  // mod first.
  friend ostream& operator<<(ostream& os, const mmint& m) {
    unsigned r = R[0], mod = M1[0];
    auto reduce1 = [&](const uint64_t& b) {
      unsigned res = (b + uint64_t(unsigned(b) * unsigned(-r)) * mod) >> 32;
      return res >= mod ? res - mod : res;
    };
    for (int i = 0; i < 8; i++) {
      os << reduce1(m[i]) << (i == 7 ? "" : " ");
    }
    return os;
  }

  // Loads the constants from `mint`, which must provide static r, n2 and
  // get_mod().
  template <typename mint>
  static void set_mod() {
    R = _mm256_set1_epi32(mint::r);
    M0 = _mm256_setzero_si256();
    M1 = _mm256_set1_epi32(mint::get_mod());
    M2 = _mm256_set1_epi32(mint::get_mod() * 2);
    N2 = _mm256_set1_epi32(mint::n2);
  }

  // Reduces eight 64-bit products to 32 bits.  prod02 carries lanes 0,2,4,6
  // and prod13 lanes 1,3,5,7.  With lo/hi the 32-bit halves of a product, the
  // result lane is hi + M1 - high32(((lo * R) mod 2^32) * M1).
  // NOTE(review): this is a Montgomery reduction provided R is the inverse of
  // the modulus modulo 2^32 -- confirm against mint::r.
  static inline mmint reduce(const mmint& prod02, const mmint& prod13) {
    // Regroup: prodlo = low halves of all 8 products, prodhi = high halves.
    m256 unpalo = _mm256_unpacklo_epi32(prod02, prod13);
    m256 unpahi = _mm256_unpackhi_epi32(prod02, prod13);
    m256 prodlo = _mm256_unpacklo_epi64(unpalo, unpahi);
    m256 prodhi = _mm256_unpackhi_epi64(unpalo, unpahi);
    m256 hiplm1 = _mm256_add_epi32(prodhi, M1);
    m256 prodlohi = _mm256_shuffle_epi32(prodlo, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(prodlo, R);
    m256 lmlr13 = _mm256_mul_epu32(prodlohi, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    return _mm256_sub_epi32(hiplm1, prod);
  }

  // Multiplies every lane by N2.
  static inline mmint itom(const mmint& A) { return A * N2; }

  // Same reduction as reduce() applied to a 32-bit input (high half zero);
  // the result lane is 0 when the quotient term is 0 and M1 - term otherwise.
  static inline mmint mtoi(const mmint& A) {
    m256 A13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(A, R);
    m256 lmlr13 = _mm256_mul_epu32(A13, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    m256 cmp = _mm256_cmpgt_epi32(prod, M0);
    m256 dif = _mm256_and_si256(cmp, M1);
    return _mm256_sub_epi32(dif, prod);
  }

  // Lane-wise A + B - M2, with M2 added back to lanes that went negative.
  friend inline mmint operator+(const mmint& A, const mmint& B) {
    m256 apb = _mm256_add_epi32(A, B);
    m256 ret = _mm256_sub_epi32(apb, M2);
    m256 cmp = _mm256_cmpgt_epi32(M0, ret);
    m256 add = _mm256_and_si256(cmp, M2);
    return _mm256_add_epi32(add, ret);
  }

  // Lane-wise A - B, with M2 added to lanes that went negative.
  friend inline mmint operator-(const mmint& A, const mmint& B) {
    m256 ret = _mm256_sub_epi32(A, B);
    m256 cmp = _mm256_cmpgt_epi32(M0, ret);
    m256 add = _mm256_and_si256(cmp, M2);
    return _mm256_add_epi32(add, ret);
  }

  // Lane-wise unsigned 32x32 -> 64 product followed by reduce().
  friend inline mmint operator*(const mmint& A, const mmint& B) {
    m256 a13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 b13 = _mm256_shuffle_epi32(B, 0xF5);
    m256 prod02 = _mm256_mul_epu32(A, B);
    m256 prod13 = _mm256_mul_epu32(a13, b13);
    return reduce(prod02, prod13);
  }

  inline mmint& operator+=(const mmint& A) { return (*this) = (*this) + A; }
  inline mmint& operator-=(const mmint& A) { return (*this) = (*this) - A; }
  inline mmint& operator*=(const mmint& A) { return (*this) = (*this) * A; }

  // Bitwise equality of the stored lanes (not equality modulo M1).
  // const-qualified so that const operands can be compared.
  bool operator==(const mmint& A) const {
    m256 sub = _mm256_sub_epi32(x, A.x);
    return _mm256_testz_si256(sub, sub) == 1;
  }
  bool operator!=(const mmint& A) const { return !((*this) == A); }
};

__attribute__((aligned(32))) mmint mmint::R;
__attribute__((aligned(32))) mmint mmint::M0, mmint::M1, mmint::M2, mmint::N2;

/**
 * @brief vectorize modint
 */
#line 8 "math-fast/subset-convolution.hpp"

/**
 * Subset convolution of two equally sized arrays using AVX2 lanes:
 *   c[s] = sum over t subset of s of a[t] * b[s ^ t].
 *
 * Layout: every index i owns a "row" of 3 mmint (3 * 8 = 24 int slots); slot r
 * of the row holds the rank-r component, i.e. the value is first stored in
 * slot popcount(i).  The function then
 *   1. runs a subset-sum (zeta) transform over the rows,
 *   2. multiplies the two rank polynomials of every index
 *      (out[l] = sum_j A[j] * B[l - j]),
 *   3. runs the inverse (Moebius) transform and reads slot popcount(i).
 *
 * Preconditions:
 *   - a.size() == b.size(), and the size is 0 or a power of two <= 2^23
 *     (a row only has room for ranks 0..23).
 *   - mmint::R / M0 / M1 / M2 have been initialised (mmint::set_mod<mint>()).
 *   - `mint::a` is read and written directly as the raw stored representation.
 *     NOTE(review): presumably the Montgomery form in [0, 2 * mod) that the
 *     mmint operators work on -- confirm against the modint in use.
 *
 * @param a first operand, indexed by subset mask
 * @param b second operand, indexed by subset mask
 * @return the subset convolution, same length as the inputs (empty for empty
 *         input)
 */
template <typename mint>
vector<mint> fast_multiply(const vector<mint>& a, const vector<mint>& b) {
  assert(a.size() == b.size());
  // __builtin_ctz(0) is undefined, so the empty case must not reach it.
  if (a.empty()) return {};
  int n = a.size();
  assert((n & (n - 1)) == 0);  // d below is only log2(n) for powers of two
  int d = __builtin_ctz(n);
  assert(d <= 23);

  // At least 8 rows are allocated because the product step always processes
  // 8 consecutive indices.  mmint() zero-initialises its lanes, so no memset
  // is needed; the vectors own the storage (no leak on early exit) and the
  // raw pointers keep the hot loops unchanged.
  vector<mmint> buf1(max(n, 8) * 3), buf2(max(n, 8) * 3);
  mmint* a1 = buf1.data();
  mmint* a2 = buf2.data();
  mmint b1[24], b2[24], b3[24];

  // Scatter every input value into the slot matching the popcount of its index.
  for (int i = 0; i < n; i++) {
    unsigned int pc = __builtin_popcount(i);
    a1[i * 3 + pc / 8][pc % 8] = a[i].a;
    a2[i * 3 + pc / 8][pc % 8] = b[i].a;
  }

  // Zeta transform.  For every even j and every h below the lowest set bit of
  // j, the 2^h rows starting at li are added onto the 2^h rows starting at ri.
  for (int j = 2; j <= n; j += 2) {
    unsigned int pc = __builtin_popcount(j);
    unsigned int ctz = __builtin_ctz(j);
    for (int h = 0; h < d; h++) {
      if (j & (1 << h)) break;
      int li = j - 2 * (1 << h), ri = j - (1 << h);
      if (pc + ctz <= 16) {
        // Every index in [li, j) has popcount <= pc + ctz - 1 <= 15, so the
        // third vector of each row (ranks 16..23) is still zero: skip it.
        for (int i = 0; i < 3 * (1 << h); i += 3) {
          a1[ri * 3 + i + 0] += a1[li * 3 + i + 0];
          a2[ri * 3 + i + 0] += a2[li * 3 + i + 0];
          a1[ri * 3 + i + 1] += a1[li * 3 + i + 1];
          a2[ri * 3 + i + 1] += a2[li * 3 + i + 1];
        }
      } else {
        for (int i = 0; i < 3 * (1 << h); i++) {
          a1[ri * 3 + i] += a1[li * 3 + i];
          a2[ri * 3 + i] += a2[li * 3 + i];
        }
      }
    }
  }

  // Accumulator ceiling 4 * mod^2: a multiple of mod, so subtracting it never
  // changes the residue, it only keeps the 64-bit sums from growing.
  mmint th = _mm256_set1_epi64x(4LL * mmint::M1[0] * mmint::M1[0]);

  // Rank-polynomial product, 8 indices (one per 32-bit lane) at a time.
  for (int is = 0; is < n; is += 8) {
    // Transpose the 8 rows into b1/b2 so that b?[rank] holds one lane per
    // index.  mpc is the smallest popcount in the group (capped at d).
    int mpc = d;
    for (int i = is; i < is + 8; i++) {
      int pc = __builtin_popcount(i);
      mpc = min(mpc, pc);
      for (int j = 0; j <= d; j++) {
        b1[j][i - is] = a1[i * 3 + j / 8][j % 8];
        b2[j][i - is] = a2[i * 3 + j / 8][j % 8];
        b3[j][i - is] = 0;
      }
    }
    // Subtract mod from every lane that is >= mod, so the signed 32x32
    // multiplications below see values in [0, mod).
    for (int j = 0; j <= d; j++) {
      m256 cmpB1 = _mm256_cmpgt_epi32(mmint::M1, b1[j]);
      m256 cmpB2 = _mm256_cmpgt_epi32(mmint::M1, b2[j]);
      m256 difB1 = _mm256_andnot_si256(cmpB1, mmint::M1);
      m256 difB2 = _mm256_andnot_si256(cmpB2, mmint::M1);
      b1[j] = _mm256_sub_epi32(b1[j], difB1);
      b2[j] = _mm256_sub_epi32(b2[j], difB2);
    }

// PROD(k): add b1[j + k] * b2[l - j - k] to the 64-bit accumulators; prod02
// collects the even lanes, prod13 the odd lanes (moved down by the shuffle).
#define PROD(k)                                              \
  m256 A13##k = _mm256_shuffle_epi32(b1[j + k], 0xF5);       \
  m256 B13##k = _mm256_shuffle_epi32(b2[l - j - k], 0xF5);   \
  m256 p02##k = _mm256_mul_epi32(b1[j + k], b2[l - j - k]);  \
  m256 p13##k = _mm256_mul_epi32(A13##k, B13##k);            \
  prod02 = _mm256_add_epi64(prod02, p02##k);                 \
  prod13 = _mm256_add_epi64(prod13, p13##k)

// COMP(): subtract th from every accumulator lane that exceeds th.
#define COMP()                                  \
  do {                                          \
    m256 cmp02 = _mm256_cmpgt_epi64(prod02, th); \
    m256 cmp13 = _mm256_cmpgt_epi64(prod13, th); \
    m256 dif02 = _mm256_and_si256(cmp02, th);    \
    m256 dif13 = _mm256_and_si256(cmp13, th);    \
    prod02 = _mm256_sub_epi64(prod02, dif02);    \
    prod13 = _mm256_sub_epi64(prod13, dif13);    \
  } while (0)

    // Output ranks below mpc are never written back, so they are skipped.
    for (int l = mpc; l <= d; l++) {
      int j = 0;
      mmint prod02 = mmint::M0, prod13 = mmint::M0;
      // Four products per ceiling check, then the remaining (at most three).
      for (; j <= l - 3; j += 4) {
        PROD(0);
        PROD(1);
        PROD(2);
        PROD(3);
        COMP();
      }
      for (; j <= l; j++) {
        PROD(0);
      }
      COMP();
      b3[l] = mmint::reduce(prod02, prod13);
    }

#undef PROD
#undef COMP

    // Transpose the products back into the rows of a1.
    for (int i = is; i < is + 8; i++) {
      for (unsigned j = mpc; j <= unsigned(d); j++) {
        a1[i * 3 + j / 8][j % 8] = b3[j][i - is];
      }
    }
  }

  // Inverse transform: same traversal as the forward pass, subtracting.
  // All three vectors of a row are processed here.
  for (int j = 2; j <= n; j += 2) {
    for (int h = 0; h < d; h++) {
      if (j & (1 << h)) break;
      int li = j - 2 * (1 << h), ri = j - (1 << h);
      for (int i = 0; i < 3 * (1 << h); i++) {
        a1[ri * 3 + i] -= a1[li * 3 + i];
      }
    }
  }

  // The answer for index i is the slot whose rank equals popcount(i).
  vector<mint> c(n);
  for (int i = 0; i < n; i++) {
    unsigned int pc = __builtin_popcount(i);
    c[i].a = a1[i * 3 + pc / 8][pc % 8];
  }
  return c;
}