Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:heavy_check_mark: math-fast/subset-convolution.hpp

Depends on

Verified with

Code

#pragma once

#include <cassert>
#include <vector>
using namespace std;

#include "../modint/vectorize-modint.hpp"

template <typename mint>
vector<mint> fast_multiply(const vector<mint>& a, const vector<mint>& b) {
  int n = a.size();
  int d = __builtin_ctz(n);
  assert(d <= 23);
  mmint* a1 = new mmint[max(n, 8) * 3];
  mmint* a2 = new mmint[max(n, 8) * 3];
  memset((void*)a1, 0, max(n, 8) * 3 * sizeof(mmint));
  memset((void*)a2, 0, max(n, 8) * 3 * sizeof(mmint));
  mmint b1[24], b2[24], b3[24];

  for (int i = 0; i < n; i++) {
    unsigned int pc = __builtin_popcount(i);
    a1[i * 3 + pc / 8][pc % 8] = a[i].a;
    a2[i * 3 + pc / 8][pc % 8] = b[i].a;
  }

  for (int j = 2; j <= n; j += 2) {
    unsigned int pc = __builtin_popcount(j);
    unsigned int ctz = __builtin_ctz(j);
    for (int h = 0; h < d; h++) {
      if (j & (1 << h)) break;
      int li = j - 2 * (1 << h), ri = j - (1 << h);
      if (pc + ctz <= 16) {
        for (int i = 0; i < 3 * (1 << h); i += 3) {
          a1[ri * 3 + i + 0] += a1[li * 3 + i + 0];
          a2[ri * 3 + i + 0] += a2[li * 3 + i + 0];
          a1[ri * 3 + i + 1] += a1[li * 3 + i + 1];
          a2[ri * 3 + i + 1] += a2[li * 3 + i + 1];
        }
      } else {
        for (int i = 0; i < 3 * (1 << h); i++) {
          a1[ri * 3 + i] += a1[li * 3 + i];
          a2[ri * 3 + i] += a2[li * 3 + i];
        }
      }
    }
  }

  mmint th = _mm256_set1_epi64x(4LL * mmint::M1[0] * mmint::M1[0]);

  for (int is = 0; is < n; is += 8) {
    int mpc = d;

    for (int i = is; i < is + 8; i++) {
      int pc = __builtin_popcount(i);
      mpc = min(mpc, pc);
      for (int j = 0; j <= d; j++) {
        b1[j][i - is] = a1[i * 3 + j / 8][j % 8];
        b2[j][i - is] = a2[i * 3 + j / 8][j % 8];
        b3[j][i - is] = 0;
      }
    }

    for (int j = 0; j <= d; j++) {
      m256 cmpB1 = _mm256_cmpgt_epi32(mmint::M1, b1[j]);
      m256 cmpB2 = _mm256_cmpgt_epi32(mmint::M1, b2[j]);
      m256 difB1 = _mm256_andnot_si256(cmpB1, mmint::M1);
      m256 difB2 = _mm256_andnot_si256(cmpB2, mmint::M1);
      b1[j] = _mm256_sub_epi32(b1[j], difB1);
      b2[j] = _mm256_sub_epi32(b2[j], difB2);
    }

#define PROD(k)                                             \
  m256 A13##k = _mm256_shuffle_epi32(b1[j + k], 0xF5);      \
  m256 B13##k = _mm256_shuffle_epi32(b2[l - j - k], 0xF5);  \
  m256 p02##k = _mm256_mul_epi32(b1[j + k], b2[l - j - k]); \
  m256 p13##k = _mm256_mul_epi32(A13##k, B13##k);           \
  prod02 = _mm256_add_epi64(prod02, p02##k);                \
  prod13 = _mm256_add_epi64(prod13, p13##k)
#define COMP()                                   \
  do {                                           \
    m256 cmp02 = _mm256_cmpgt_epi64(prod02, th); \
    m256 cmp13 = _mm256_cmpgt_epi64(prod13, th); \
    m256 dif02 = _mm256_and_si256(cmp02, th);    \
    m256 dif13 = _mm256_and_si256(cmp13, th);    \
    prod02 = _mm256_sub_epi64(prod02, dif02);    \
    prod13 = _mm256_sub_epi64(prod13, dif13);    \
  } while (0)

    for (int l = mpc; l <= d; l++) {
      int j = 0;
      mmint prod02 = mmint::M0, prod13 = mmint::M0;

      for (; j <= l - 3; j += 4) {
        PROD(0);
        PROD(1);
        PROD(2);
        PROD(3);
        COMP();
      }
      for (; j <= l; j++) {
        PROD(0);
      }
      COMP();
      b3[l] = mmint::reduce(prod02, prod13);
    }

#undef PROD
#undef COMP

    for (int i = is; i < is + 8; i++) {
      for (unsigned j = mpc; j <= unsigned(d); j++) {
        a1[i * 3 + j / 8][j % 8] = b3[j][i - is];
      }
    }
  }

  for (int j = 2; j <= n; j += 2) {
    for (int h = 0; h < d; h++) {
      if (j & (1 << h)) break;
      int li = j - 2 * (1 << h), ri = j - (1 << h);
      for (int i = 0; i < 3 * (1 << h); i++) {
        a1[ri * 3 + i] -= a1[li * 3 + i];
      }
    }
  }

  vector<mint> c(n);
  for (int i = 0; i < n; i++) {
    unsigned int pc = __builtin_popcount(i);
    c[i].a = a1[i * 3 + pc / 8][pc % 8];
  }

  delete[] (a1);
  delete[] (a2);
  return c;
}
#line 2 "math-fast/subset-convolution.hpp"

#include <cassert>
#include <vector>
using namespace std;

#line 2 "modint/vectorize-modint.hpp"

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")

#include <immintrin.h>
#include <iostream>
using namespace std;

using m256 = __m256i;

/**
 * Eight 32-bit modular residues packed into one 256-bit AVX2 register.
 *
 * The modulus-dependent constants are shared static members and must be set
 * (see set_mod) before any arithmetic is performed:
 *   - M0: all zero
 *   - M1: the modulus in every lane
 *   - M2: twice the modulus in every lane
 *   - R : mint::r in every lane (used by the reductions as the inverse of the
 *         modulus modulo 2^32 -- confirm against the scalar modint)
 *   - N2: mint::n2 in every lane (multiplier used by itom)
 *
 * operator+ and operator- wrap their result by M2, not M1: given operands in
 * [0, 2 * mod) they return a value in [0, 2 * mod), i.e. lanes are kept only
 * partially reduced.
 */
struct alignas(32) mmint {
  m256 x;
  static mmint R, M0, M1, M2, N2;

  // Zero in every lane.
  mmint() : x() {}
  // Wraps a raw register without touching the lane values.
  inline mmint(const m256& _x) : x(_x) {}
  // Broadcasts `a` to all eight lanes (no reduction or conversion).
  inline mmint(unsigned int a) : x(_mm256_set1_epi32(a)) {}
  // Lane i receives a_i (the intrinsic takes its arguments high lane first).
  inline mmint(unsigned int a0, unsigned int a1, unsigned int a2,
               unsigned int a3, unsigned int a4, unsigned int a5,
               unsigned int a6, unsigned int a7)
      : x(_mm256_set_epi32(a7, a6, a5, a4, a3, a2, a1, a0)) {}
  // Implicit conversions so an mmint can be passed straight to intrinsics.
  inline operator m256&() { return x; }
  inline operator const m256&() const { return x; }
  // Access to lane i (0..7) as a signed 32-bit integer.
  inline int& operator[](int i) { return *(reinterpret_cast<int*>(&x) + i); }
  inline const int& operator[](int i) const {
    return *(reinterpret_cast<const int*>(&x) + i);
  }

  // Prints the eight lanes separated by single spaces. Each lane is passed
  // through a scalar reduction (using R and M1) and brought below the modulus
  // before being written.
  friend ostream& operator<<(ostream& os, const mmint& m) {
    unsigned r = R[0], mod = M1[0];
    auto reduce1 = [&](const uint64_t& b) {
      unsigned res = (b + uint64_t(unsigned(b) * unsigned(-r)) * mod) >> 32;
      return res >= mod ? res - mod : res;
    };
    for (int i = 0; i < 8; i++) {
      os << reduce1(m[i]) << (i == 7 ? "" : " ");
    }
    return os;
  }

  // Loads the shared constants from the scalar modint type `mint`, which has
  // to provide `mint::r`, `mint::n2` and `mint::get_mod()`.
  template <typename mint>
  static void set_mod() {
    R = _mm256_set1_epi32(mint::r);
    M0 = _mm256_setzero_si256();
    M1 = _mm256_set1_epi32(mint::get_mod());
    M2 = _mm256_set1_epi32(mint::get_mod() * 2);
    N2 = _mm256_set1_epi32(mint::n2);
  }

  // Reduces eight 64-bit values to eight 32-bit lanes.
  // prod02 carries the values for lanes 0, 2, 4, 6 and prod13 those for lanes
  // 1, 3, 5, 7. With hi/lo the upper/lower 32 bits of a value, each result
  // lane is
  //   hi + M1 - high32((lo * R mod 2^32) * M1),
  // which is a Montgomery-style reduction provided R is the inverse of the
  // modulus modulo 2^32.
  static inline mmint reduce(const mmint& prod02, const mmint& prod13) {
    // Regroup the 64-bit values into a vector of low words and one of high
    // words, both in natural lane order.
    m256 unpalo = _mm256_unpacklo_epi32(prod02, prod13);
    m256 unpahi = _mm256_unpackhi_epi32(prod02, prod13);
    m256 prodlo = _mm256_unpacklo_epi64(unpalo, unpahi);
    m256 prodhi = _mm256_unpackhi_epi64(unpalo, unpahi);
    m256 hiplm1 = _mm256_add_epi32(prodhi, M1);
    // 0xF5 copies the odd lanes into the even positions for mul_epu32.
    m256 prodlohi = _mm256_shuffle_epi32(prodlo, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(prodlo, R);
    m256 lmlr13 = _mm256_mul_epu32(prodlohi, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    // Collect the high 32 bits of the eight 64-bit products.
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    return _mm256_sub_epi32(hiplm1, prod);
  }

  // Multiplies by N2; presumably converts plain integers into the internal
  // (Montgomery) representation -- confirm against the meaning of mint::n2.
  static inline mmint itom(const mmint& A) { return A * N2; }

  // Reduction of a 32-bit lane value (the 64-bit input of `reduce` with a zero
  // high word). The result is M1 - q when q > 0 and 0 otherwise, where
  // q = high32((A * R mod 2^32) * M1).
  static inline mmint mtoi(const mmint& A) {
    m256 A13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(A, R);
    m256 lmlr13 = _mm256_mul_epu32(A13, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    m256 cmp = _mm256_cmpgt_epi32(prod, M0);
    m256 dif = _mm256_and_si256(cmp, M1);
    return _mm256_sub_epi32(dif, prod);
  }

  // Lane-wise A + B - M2, with M2 added back to lanes that went negative.
  friend inline mmint operator+(const mmint& A, const mmint& B) {
    m256 apb = _mm256_add_epi32(A, B);
    m256 ret = _mm256_sub_epi32(apb, M2);
    m256 cmp = _mm256_cmpgt_epi32(M0, ret);
    m256 add = _mm256_and_si256(cmp, M2);
    return _mm256_add_epi32(add, ret);
  }

  // Lane-wise A - B, with M2 added to lanes that went negative.
  friend inline mmint operator-(const mmint& A, const mmint& B) {
    m256 ret = _mm256_sub_epi32(A, B);
    m256 cmp = _mm256_cmpgt_epi32(M0, ret);
    m256 add = _mm256_and_si256(cmp, M2);
    return _mm256_add_epi32(add, ret);
  }

  // Lane-wise product: full 64-bit products of the even and of the odd lanes,
  // followed by `reduce`.
  friend inline mmint operator*(const mmint& A, const mmint& B) {
    m256 a13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 b13 = _mm256_shuffle_epi32(B, 0xF5);
    m256 prod02 = _mm256_mul_epu32(A, B);
    m256 prod13 = _mm256_mul_epu32(a13, b13);
    return reduce(prod02, prod13);
  }

  inline mmint& operator+=(const mmint& A) { return (*this) = (*this) + A; }
  inline mmint& operator-=(const mmint& A) { return (*this) = (*this) - A; }
  inline mmint& operator*=(const mmint& A) { return (*this) = (*this) * A; }

  // Bitwise equality of all eight lanes. Lanes that are congruent modulo the
  // modulus but stored as different integers compare unequal.
  // const-qualified so that const objects and temporaries can be compared.
  bool operator==(const mmint& A) const {
    m256 sub = _mm256_sub_epi32(x, A.x);
    return _mm256_testz_si256(sub, sub) == 1;
  }
  bool operator!=(const mmint& A) const { return !((*this) == A); }
};
__attribute__((aligned(32))) mmint mmint::R;
__attribute__((aligned(32))) mmint mmint::M0, mmint::M1, mmint::M2, mmint::N2;

/**
 * @brief vectorize modint
 */
#line 8 "math-fast/subset-convolution.hpp"

template <typename mint>
vector<mint> fast_multiply(const vector<mint>& a, const vector<mint>& b) {
  int n = a.size();
  int d = __builtin_ctz(n);
  assert(d <= 23);
  mmint* a1 = new mmint[max(n, 8) * 3];
  mmint* a2 = new mmint[max(n, 8) * 3];
  memset((void*)a1, 0, max(n, 8) * 3 * sizeof(mmint));
  memset((void*)a2, 0, max(n, 8) * 3 * sizeof(mmint));
  mmint b1[24], b2[24], b3[24];

  for (int i = 0; i < n; i++) {
    unsigned int pc = __builtin_popcount(i);
    a1[i * 3 + pc / 8][pc % 8] = a[i].a;
    a2[i * 3 + pc / 8][pc % 8] = b[i].a;
  }

  for (int j = 2; j <= n; j += 2) {
    unsigned int pc = __builtin_popcount(j);
    unsigned int ctz = __builtin_ctz(j);
    for (int h = 0; h < d; h++) {
      if (j & (1 << h)) break;
      int li = j - 2 * (1 << h), ri = j - (1 << h);
      if (pc + ctz <= 16) {
        for (int i = 0; i < 3 * (1 << h); i += 3) {
          a1[ri * 3 + i + 0] += a1[li * 3 + i + 0];
          a2[ri * 3 + i + 0] += a2[li * 3 + i + 0];
          a1[ri * 3 + i + 1] += a1[li * 3 + i + 1];
          a2[ri * 3 + i + 1] += a2[li * 3 + i + 1];
        }
      } else {
        for (int i = 0; i < 3 * (1 << h); i++) {
          a1[ri * 3 + i] += a1[li * 3 + i];
          a2[ri * 3 + i] += a2[li * 3 + i];
        }
      }
    }
  }

  mmint th = _mm256_set1_epi64x(4LL * mmint::M1[0] * mmint::M1[0]);

  for (int is = 0; is < n; is += 8) {
    int mpc = d;

    for (int i = is; i < is + 8; i++) {
      int pc = __builtin_popcount(i);
      mpc = min(mpc, pc);
      for (int j = 0; j <= d; j++) {
        b1[j][i - is] = a1[i * 3 + j / 8][j % 8];
        b2[j][i - is] = a2[i * 3 + j / 8][j % 8];
        b3[j][i - is] = 0;
      }
    }

    for (int j = 0; j <= d; j++) {
      m256 cmpB1 = _mm256_cmpgt_epi32(mmint::M1, b1[j]);
      m256 cmpB2 = _mm256_cmpgt_epi32(mmint::M1, b2[j]);
      m256 difB1 = _mm256_andnot_si256(cmpB1, mmint::M1);
      m256 difB2 = _mm256_andnot_si256(cmpB2, mmint::M1);
      b1[j] = _mm256_sub_epi32(b1[j], difB1);
      b2[j] = _mm256_sub_epi32(b2[j], difB2);
    }

#define PROD(k)                                             \
  m256 A13##k = _mm256_shuffle_epi32(b1[j + k], 0xF5);      \
  m256 B13##k = _mm256_shuffle_epi32(b2[l - j - k], 0xF5);  \
  m256 p02##k = _mm256_mul_epi32(b1[j + k], b2[l - j - k]); \
  m256 p13##k = _mm256_mul_epi32(A13##k, B13##k);           \
  prod02 = _mm256_add_epi64(prod02, p02##k);                \
  prod13 = _mm256_add_epi64(prod13, p13##k)
#define COMP()                                   \
  do {                                           \
    m256 cmp02 = _mm256_cmpgt_epi64(prod02, th); \
    m256 cmp13 = _mm256_cmpgt_epi64(prod13, th); \
    m256 dif02 = _mm256_and_si256(cmp02, th);    \
    m256 dif13 = _mm256_and_si256(cmp13, th);    \
    prod02 = _mm256_sub_epi64(prod02, dif02);    \
    prod13 = _mm256_sub_epi64(prod13, dif13);    \
  } while (0)

    for (int l = mpc; l <= d; l++) {
      int j = 0;
      mmint prod02 = mmint::M0, prod13 = mmint::M0;

      for (; j <= l - 3; j += 4) {
        PROD(0);
        PROD(1);
        PROD(2);
        PROD(3);
        COMP();
      }
      for (; j <= l; j++) {
        PROD(0);
      }
      COMP();
      b3[l] = mmint::reduce(prod02, prod13);
    }

#undef PROD
#undef COMP

    for (int i = is; i < is + 8; i++) {
      for (unsigned j = mpc; j <= unsigned(d); j++) {
        a1[i * 3 + j / 8][j % 8] = b3[j][i - is];
      }
    }
  }

  for (int j = 2; j <= n; j += 2) {
    for (int h = 0; h < d; h++) {
      if (j & (1 << h)) break;
      int li = j - 2 * (1 << h), ri = j - (1 << h);
      for (int i = 0; i < 3 * (1 << h); i++) {
        a1[ri * 3 + i] -= a1[li * 3 + i];
      }
    }
  }

  vector<mint> c(n);
  for (int i = 0; i < n; i++) {
    unsigned int pc = __builtin_popcount(i);
    c[i].a = a1[i * 3 + pc / 8][pc % 8];
  }

  delete[] (a1);
  delete[] (a2);
  return c;
}
Back to top page