Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

⚠ math-fast/vectorize-modint.hpp

Code

#pragma once

#include <immintrin.h>

#include <array>
#include <cstring>
#include <ostream>

using namespace std;

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")

template <typename _mint>
struct alignas(32) vectorize_modint {
  // Eight 32-bit lanes of a Montgomery-form modint `_mint`, packed into one
  // AVX2 register; +, -, * operate lane-wise.
  //
  // Requirements on `_mint`, read off from how the members below use it
  // (NOTE(review): confirm against the modint actually plugged in):
  //   * get_mod() is the modulus M, and 2*M must fit in 32 bits (M2 stores
  //     it per lane and + / - keep results below it);
  //   * r satisfies r * M == 1 (mod 2^32); reduce() and mtoi() need this for
  //     the low 32 bits of b - m*M to cancel;
  //   * n2 == 2^64 mod M, so that itom(a) == a * 2^32 (mod M).
  //
  // Lane values are kept "lazily" in [0, 2M): + and - stay in that range for
  // inputs in it, and so does * when additionally M < 2^30 (then the high
  // half of a product of two values < 2M is below M). mtoi() returns fully
  // reduced plain integers in [0, M).
  using mint = _mint;
  using m256 = __m256i;
  using vmint = vectorize_modint;
  m256 x;

  // Per-lane broadcast constants. Their initializers go through a
  // non-constexpr intrinsic constructor, so they may be dynamically
  // initialized. Dynamic initialization of static members of a class
  // template is unordered relative to other namespace-scope objects.
  // NOTE(review): do not rely on vmint arithmetic while initializing
  // another namespace-scope object.
  inline static vmint R = mint::r;               // expected: M^{-1} mod 2^32
  inline static vmint M0 = 0;                    // all lanes zero
  inline static vmint M1 = mint::get_mod();      // M
  inline static vmint M2 = mint::get_mod() * 2;  // 2M, the lazy upper bound
  inline static vmint N2 = mint::n2;             // expected: 2^64 mod M

  vectorize_modint() = default;
  // Broadcast the raw bits of `a` to all eight lanes (no Montgomery
  // conversion; use itom() for that).
  vectorize_modint(int a) : x(_mm256_set1_epi32(a)) {}
  vectorize_modint(const m256& _x) : x(_x) {}
  // Lane i <- a[i] (unaligned load, raw bits). The const-correct cast keeps
  // the read-only source read-only.
  vectorize_modint(const std::array<int, 8>& a)
      : x(_mm256_loadu_si256(reinterpret_cast<const m256*>(a.data()))) {}
  // Lane i <- ai. _mm256_set_epi32 takes its arguments highest lane first,
  // hence the reversed order.
  vectorize_modint(int a0, int a1, int a2, int a3, int a4, int a5, int a6,
                   int a7)
      : x(_mm256_set_epi32(a7, a6, a5, a4, a3, a2, a1, a0)) {}

  // Raw 32-bit value of lane i (0 <= i < 8; not bounds-checked).
  // Copies through a byte pointer. Dereferencing an int* into the __m256i
  // member would rely on compiler-specific aliasing behavior, whereas
  // byte-wise access to an object's representation is always permitted.
  int at(int i) const {
    int v;
    std::memcpy(&v, reinterpret_cast<const char*>(&x) + sizeof(int) * i,
                sizeof(int));
    return v;
  }
  // Overwrite the raw 32-bit value of lane i (0 <= i < 8; not
  // bounds-checked). Uses the same byte-wise access as at().
  void set(int i, int val) {
    std::memcpy(reinterpret_cast<char*>(&x) + sizeof(int) * i, &val,
                sizeof(int));
  }
  // Lets a vmint be passed directly to _mm256_* intrinsics.
  operator const __m256i&() const { return x; }

  // Print the eight lanes as plain integers (converted with mtoi),
  // separated by single spaces, with no trailing separator.
  friend std::ostream& operator<<(std::ostream& os, const vmint& m) {
    vmint a = mtoi(m);
    for (int i = 0; i < 8; i++) os << a.at(i) << (i == 7 ? "" : " ");
    return os;
  }

  // Montgomery reduction of eight 64-bit products.
  // prod02 holds the 64-bit products of lanes 0,2,4,6 and prod13 those of
  // lanes 1,3,5,7 (as produced by _mm256_mul_epu32). For a lane whose product
  // is b = hi * 2^32 + lo, the result is hi + M - high32(m * M) with
  // m = lo * R mod 2^32. That is congruent to b * 2^{-32} (mod M) when
  // R * M == 1 (mod 2^32).
  static vmint reduce(const vmint& prod02, const vmint& prod13) {
    // Regroup into prodlo = low halves and prodhi = high halves of lanes
    // 0..7, in natural lane order.
    m256 unpalo = _mm256_unpacklo_epi32(prod02, prod13);
    m256 unpahi = _mm256_unpackhi_epi32(prod02, prod13);
    m256 prodlo = _mm256_unpacklo_epi64(unpalo, unpahi);
    m256 prodhi = _mm256_unpackhi_epi64(unpalo, unpahi);
    // Adding M up front keeps the final subtraction non-negative.
    m256 hiplm1 = _mm256_add_epi32(prodhi, M1);
    // m = lo * R (mod 2^32), computed for even and odd lanes separately.
    m256 prodlohi = _mm256_shuffle_epi32(prodlo, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(prodlo, R);
    m256 lmlr13 = _mm256_mul_epu32(prodlohi, R);
    // high32(m * M) for every lane.
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    return _mm256_sub_epi32(hiplm1, prod);
  }

  // Plain integers -> Montgomery form: reduce(A * N2), which equals
  // A * 2^32 (mod M) given N2 == 2^64 mod M.
  static vmint itom(const vmint& A) { return A * N2; }

  // Montgomery form -> plain integer, fully reduced into [0, M).
  // Here the 64-bit value is just A (high half 0), so the reduction yields
  // -p with p = high32(m * M), 0 <= p < M. The result is M - p, or 0 when
  // p == 0.
  static vmint mtoi(const vmint& A) {
    m256 A13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(A, R);
    m256 lmlr13 = _mm256_mul_epu32(A13, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    // Mask is all-ones exactly in lanes where p > 0; select M there, 0 else.
    m256 cmp = _mm256_cmpgt_epi32(prod, M0);
    m256 dif = _mm256_and_si256(cmp, M1);
    return _mm256_sub_epi32(dif, prod);
  }

  // Lane-wise A + B kept in [0, 2M). Let s = A + B. If s < 2M, then s - 2M
  // wraps to a huge unsigned value and the unsigned min keeps s. Otherwise
  // the min picks s - 2M.
  __attribute__((target("avx2"), optimize("O3", "unroll-loops"))) friend vmint
  operator+(const vmint& A, const vmint& B) {
    m256 apb = _mm256_add_epi32(A, B);
    m256 ret = _mm256_sub_epi32(apb, M2);
    return _mm256_min_epu32(apb, ret);
  }
  // Lane-wise A - B kept in [0, 2M). If A < B, then d = A - B wraps to a huge
  // unsigned value and d + 2M (also wrapping) is the smaller, correct one.
  __attribute__((target("avx2"), optimize("O3", "unroll-loops"))) friend vmint
  operator-(const vmint& A, const vmint& B) {
    m256 amb = _mm256_sub_epi32(A, B);
    m256 ret = _mm256_add_epi32(amb, M2);
    return _mm256_min_epu32(amb, ret);
  }
  // Lane-wise Montgomery product: full 64-bit products of even and odd
  // lanes (odd lanes shuffled into even positions first), then reduce().
  __attribute__((target("avx2"), optimize("O3", "unroll-loops"))) friend vmint
  operator*(const vmint& A, const vmint& B) {
    m256 a13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 b13 = _mm256_shuffle_epi32(B, 0xF5);
    m256 prod02 = _mm256_mul_epu32(A, B);
    m256 prod13 = _mm256_mul_epu32(a13, b13);
    return reduce(prod02, prod13);
  }
  vmint& operator+=(const vmint& A) { return (*this) = (*this) + A; }
  vmint& operator-=(const vmint& A) { return (*this) = (*this) - A; }
  vmint& operator*=(const vmint& A) { return (*this) = (*this) * A; }
};
#line 2 "math-fast/vectorize-modint.hpp"

#include <immintrin.h>

#include <array>

using namespace std;

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")

template <typename _mint>
struct alignas(32) vectorize_modint {
  // Eight 32-bit lanes of a Montgomery-form modint `_mint`, packed into one
  // AVX2 register; +, -, * operate lane-wise.
  //
  // Requirements on `_mint`, read off from how the members below use it
  // (NOTE(review): confirm against the modint actually plugged in):
  //   * get_mod() is the modulus M, and 2*M must fit in 32 bits (M2 stores
  //     it per lane and + / - keep results below it);
  //   * r satisfies r * M == 1 (mod 2^32); reduce() and mtoi() need this for
  //     the low 32 bits of b - m*M to cancel;
  //   * n2 == 2^64 mod M, so that itom(a) == a * 2^32 (mod M).
  //
  // Lane values are kept "lazily" in [0, 2M): + and - stay in that range for
  // inputs in it, and so does * when additionally M < 2^30 (then the high
  // half of a product of two values < 2M is below M). mtoi() returns fully
  // reduced plain integers in [0, M).
  using mint = _mint;
  using m256 = __m256i;
  using vmint = vectorize_modint;
  m256 x;

  // Per-lane broadcast constants. Their initializers go through a
  // non-constexpr intrinsic constructor, so they may be dynamically
  // initialized. Dynamic initialization of static members of a class
  // template is unordered relative to other namespace-scope objects.
  // NOTE(review): do not rely on vmint arithmetic while initializing
  // another namespace-scope object.
  inline static vmint R = mint::r;               // expected: M^{-1} mod 2^32
  inline static vmint M0 = 0;                    // all lanes zero
  inline static vmint M1 = mint::get_mod();      // M
  inline static vmint M2 = mint::get_mod() * 2;  // 2M, the lazy upper bound
  inline static vmint N2 = mint::n2;             // expected: 2^64 mod M

  vectorize_modint() = default;
  // Broadcast the raw bits of `a` to all eight lanes (no Montgomery
  // conversion; use itom() for that).
  vectorize_modint(int a) : x(_mm256_set1_epi32(a)) {}
  vectorize_modint(const m256& _x) : x(_x) {}
  // Lane i <- a[i] (unaligned load, raw bits). The const-correct cast keeps
  // the read-only source read-only.
  vectorize_modint(const std::array<int, 8>& a)
      : x(_mm256_loadu_si256(reinterpret_cast<const m256*>(a.data()))) {}
  // Lane i <- ai. _mm256_set_epi32 takes its arguments highest lane first,
  // hence the reversed order.
  vectorize_modint(int a0, int a1, int a2, int a3, int a4, int a5, int a6,
                   int a7)
      : x(_mm256_set_epi32(a7, a6, a5, a4, a3, a2, a1, a0)) {}

  // Raw 32-bit value of lane i (0 <= i < 8; not bounds-checked).
  // Copies through a byte pointer. Dereferencing an int* into the __m256i
  // member would rely on compiler-specific aliasing behavior, whereas
  // byte-wise access to an object's representation is always permitted.
  int at(int i) const {
    int v;
    std::memcpy(&v, reinterpret_cast<const char*>(&x) + sizeof(int) * i,
                sizeof(int));
    return v;
  }
  // Overwrite the raw 32-bit value of lane i (0 <= i < 8; not
  // bounds-checked). Uses the same byte-wise access as at().
  void set(int i, int val) {
    std::memcpy(reinterpret_cast<char*>(&x) + sizeof(int) * i, &val,
                sizeof(int));
  }
  // Lets a vmint be passed directly to _mm256_* intrinsics.
  operator const __m256i&() const { return x; }

  // Print the eight lanes as plain integers (converted with mtoi),
  // separated by single spaces, with no trailing separator.
  friend std::ostream& operator<<(std::ostream& os, const vmint& m) {
    vmint a = mtoi(m);
    for (int i = 0; i < 8; i++) os << a.at(i) << (i == 7 ? "" : " ");
    return os;
  }

  // Montgomery reduction of eight 64-bit products.
  // prod02 holds the 64-bit products of lanes 0,2,4,6 and prod13 those of
  // lanes 1,3,5,7 (as produced by _mm256_mul_epu32). For a lane whose product
  // is b = hi * 2^32 + lo, the result is hi + M - high32(m * M) with
  // m = lo * R mod 2^32. That is congruent to b * 2^{-32} (mod M) when
  // R * M == 1 (mod 2^32).
  static vmint reduce(const vmint& prod02, const vmint& prod13) {
    // Regroup into prodlo = low halves and prodhi = high halves of lanes
    // 0..7, in natural lane order.
    m256 unpalo = _mm256_unpacklo_epi32(prod02, prod13);
    m256 unpahi = _mm256_unpackhi_epi32(prod02, prod13);
    m256 prodlo = _mm256_unpacklo_epi64(unpalo, unpahi);
    m256 prodhi = _mm256_unpackhi_epi64(unpalo, unpahi);
    // Adding M up front keeps the final subtraction non-negative.
    m256 hiplm1 = _mm256_add_epi32(prodhi, M1);
    // m = lo * R (mod 2^32), computed for even and odd lanes separately.
    m256 prodlohi = _mm256_shuffle_epi32(prodlo, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(prodlo, R);
    m256 lmlr13 = _mm256_mul_epu32(prodlohi, R);
    // high32(m * M) for every lane.
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    return _mm256_sub_epi32(hiplm1, prod);
  }

  // Plain integers -> Montgomery form: reduce(A * N2), which equals
  // A * 2^32 (mod M) given N2 == 2^64 mod M.
  static vmint itom(const vmint& A) { return A * N2; }

  // Montgomery form -> plain integer, fully reduced into [0, M).
  // Here the 64-bit value is just A (high half 0), so the reduction yields
  // -p with p = high32(m * M), 0 <= p < M. The result is M - p, or 0 when
  // p == 0.
  static vmint mtoi(const vmint& A) {
    m256 A13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(A, R);
    m256 lmlr13 = _mm256_mul_epu32(A13, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    // Mask is all-ones exactly in lanes where p > 0; select M there, 0 else.
    m256 cmp = _mm256_cmpgt_epi32(prod, M0);
    m256 dif = _mm256_and_si256(cmp, M1);
    return _mm256_sub_epi32(dif, prod);
  }

  // Lane-wise A + B kept in [0, 2M). Let s = A + B. If s < 2M, then s - 2M
  // wraps to a huge unsigned value and the unsigned min keeps s. Otherwise
  // the min picks s - 2M.
  __attribute__((target("avx2"), optimize("O3", "unroll-loops"))) friend vmint
  operator+(const vmint& A, const vmint& B) {
    m256 apb = _mm256_add_epi32(A, B);
    m256 ret = _mm256_sub_epi32(apb, M2);
    return _mm256_min_epu32(apb, ret);
  }
  // Lane-wise A - B kept in [0, 2M). If A < B, then d = A - B wraps to a huge
  // unsigned value and d + 2M (also wrapping) is the smaller, correct one.
  __attribute__((target("avx2"), optimize("O3", "unroll-loops"))) friend vmint
  operator-(const vmint& A, const vmint& B) {
    m256 amb = _mm256_sub_epi32(A, B);
    m256 ret = _mm256_add_epi32(amb, M2);
    return _mm256_min_epu32(amb, ret);
  }
  // Lane-wise Montgomery product: full 64-bit products of even and odd
  // lanes (odd lanes shuffled into even positions first), then reduce().
  __attribute__((target("avx2"), optimize("O3", "unroll-loops"))) friend vmint
  operator*(const vmint& A, const vmint& B) {
    m256 a13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 b13 = _mm256_shuffle_epi32(B, 0xF5);
    m256 prod02 = _mm256_mul_epu32(A, B);
    m256 prod13 = _mm256_mul_epu32(a13, b13);
    return reduce(prod02, prod13);
  }
  vmint& operator+=(const vmint& A) { return (*this) = (*this) + A; }
  vmint& operator-=(const vmint& A) { return (*this) = (*this) - A; }
  vmint& operator*=(const vmint& A) { return (*this) = (*this) * A; }
};
Back to top page