Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:heavy_check_mark: verify/verify-yosupo-math/yosupo-matrix-product-vectorize-modint.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
//
#include <immintrin.h>
//
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

using namespace std;

#include "../../misc/fastio.hpp"
//
#include "../../modint/montgomery-modint.hpp"
//
#include "../../math-fast/mat-prod-strassen.hpp"
#include "../../modint/vectorize-modint.hpp"

// Entry point for the "matrix_product" verification: reads an N x M matrix
// and an M x K matrix, multiplies them with fast_mat_prod_impl::prod and
// prints the N x K result, one row per line.
int main() {
  using mint = LazyMontgomeryModInt<998244353>;
  mmint::set_mod<mint>();

  using namespace fast_mat_prod_impl;

#ifdef PROFILER
  {
    // Benchmark mode: fill both operands with row + col and run the
    // full-size product 100 times without touching stdin/stdout.
    unsigned int* lhs = reinterpret_cast<unsigned int*>(t);
    unsigned int* rhs = reinterpret_cast<unsigned int*>(u);
    unsigned int* out = reinterpret_cast<unsigned int*>(s);
    for (int row = 0; row < S; row++) {
      unsigned int* lhs_row = lhs + row * S;
      unsigned int* rhs_row = rhs + row * S;
      for (int col = 0; col < S; col++) {
        const unsigned int v = row + col;
        lhs_row[col] = v;
        rhs_row[col] = v;
      }
    }
    int remaining = 100;
    while (remaining-- > 0) prod(lhs, rhs, out);
    return 0;
  }
#endif

  int N, M, K;
  rd(N, M, K);

  // The library's work arrays double as staging areas with a row stride of
  // S values: the left operand is staged in t, the right operand in u, and
  // prod() hands the result back through s.
  unsigned int* const lhs = reinterpret_cast<unsigned int*>(t);
  unsigned int* const rhs = reinterpret_cast<unsigned int*>(u);
  unsigned int* const out = reinterpret_cast<unsigned int*>(s);

  for (int i = 0; i < N; i++) {
    unsigned int* lhs_row = lhs + i * S;
    for (int k = 0; k < M; k++) rd(lhs_row[k]);
  }
  for (int k = 0; k < M; k++) {
    unsigned int* rhs_row = rhs + k * S;
    for (int j = 0; j < K; j++) rd(rhs_row[j]);
  }

  prod(lhs, rhs, out);

  for (int i = 0; i < N; i++) {
    const unsigned int* out_row = out + i * S;
    for (int j = 0; j < K; j++) {
      const unsigned int value = out_row[j];
      wt(value);
      // Space between entries, newline after the last entry of a row.
      wt(j + 1 == K ? '\n' : ' ');
    }
  }
}
#line 1 "verify/verify-yosupo-math/yosupo-matrix-product-vectorize-modint.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
//
#include <immintrin.h>
//
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

using namespace std;

#line 2 "misc/fastio.hpp"

#line 5 "misc/fastio.hpp"
#include <string>
#line 8 "misc/fastio.hpp"

using namespace std;

#line 2 "internal/internal-type-traits.hpp"

#line 4 "internal/internal-type-traits.hpp"
using namespace std;

namespace internal {
// "Broad" integer traits: like the <type_traits> versions, but they also
// accept the 128-bit extension types explicitly.

// true_type for every standard integral type and for __int128_t/__uint128_t.
template <typename T>
using is_broadly_integral =
    bool_constant<is_integral_v<T> || is_same_v<T, __int128_t> ||
                  is_same_v<T, __uint128_t>>;

// true_type when is_signed_v<T> holds or T is __int128_t.
template <typename T>
using is_broadly_signed =
    bool_constant<is_signed_v<T> || is_same_v<T, __int128_t>>;

// true_type when is_unsigned_v<T> holds or T is __uint128_t.
template <typename T>
using is_broadly_unsigned =
    bool_constant<is_unsigned_v<T> || is_same_v<T, __uint128_t>>;

// Variable-template shorthands for the three traits above.
template <typename T>
constexpr bool is_broadly_integral_v = is_broadly_integral<T>::value;
template <typename T>
constexpr bool is_broadly_signed_v = is_broadly_signed<T>::value;
template <typename T>
constexpr bool is_broadly_unsigned_v = is_broadly_unsigned<T>::value;

// ENABLE_HAS_TYPE(var) defines has_var<T> / has_var_v<T>, which are true
// when T has a nested type named `var`.
#define ENABLE_HAS_TYPE(var)                                   \
  template <class, class = void>                               \
  struct has_##var : false_type {};                            \
  template <class T>                                           \
  struct has_##var<T, void_t<typename T::var>> : true_type {}; \
  template <class T>                                           \
  constexpr auto has_##var##_v = has_##var<T>::value;

// ENABLE_HAS_VAR(var) defines has_var<T> / has_var_v<T>, which are true
// when the expression T::var is well-formed.
#define ENABLE_HAS_VAR(var)                                     \
  template <class, class = void>                                \
  struct has_##var : false_type {};                             \
  template <class T>                                            \
  struct has_##var<T, void_t<decltype(T::var)>> : true_type {}; \
  template <class T>                                            \
  constexpr auto has_##var##_v = has_##var<T>::value;

}  // namespace internal
#line 12 "misc/fastio.hpp"

// Buffered stdin/stdout helpers. rd(...) parses whitespace-separated tokens
// from stdin, wt(...) / wtn(...) append to an output buffer that is flushed
// when it fills up and once more at program exit.
namespace fastio {
// Size in bytes of each of the two buffers.
static constexpr int SZ = 1 << 17;
// Slack kept at the end of the buffers: readers refill when fewer than
// `offset` bytes are buffered, writers flush when fewer than `offset` bytes
// are free, so one token can be handled without further bounds checks.
static constexpr int offset = 64;
char inbuf[SZ], outbuf[SZ];
// inbuf[in_left, in_right) is unread input; outbuf[0, out_right) is pending
// output.
int in_left = 0, in_right = 0, out_right = 0;

// Lookup table: num[4 * i .. 4 * i + 3] is the zero-padded 4-digit decimal
// text of i for 0 <= i < 10000.
struct Pre {
  char num[40000];
  constexpr Pre() : num() {
    for (int i = 0; i < 10000; i++) {
      int n = i;
      for (int j = 3; j >= 0; j--) {
        num[i * 4 + j] = n % 10 + '0';
        n /= 10;
      }
    }
  }
} constexpr pre;

// magnitude_t<T>: the unsigned type of the same width for signed integer
// types (including __int128_t); every other type maps to itself. Doing the
// digit arithmetic in this type keeps the most negative value of a signed
// type free of signed overflow.
template <typename T, typename = void>
struct magnitude {
  using type = T;
};
template <typename T>
struct magnitude<T, enable_if_t<is_same_v<T, __int128_t>>> {
  using type = __uint128_t;
};
template <typename T>
struct magnitude<T, enable_if_t<is_integral_v<T> && is_signed_v<T> &&
                                !is_same_v<T, __int128_t>>> {
  using type = make_unsigned_t<T>;
};
template <typename T>
using magnitude_t = typename magnitude<T>::type;

// Moves the unread bytes to the front of inbuf and refills the rest from
// stdin.
void load() {
  int len = in_right - in_left;
  memmove(inbuf, inbuf + in_left, len);
  in_right = len + fread(inbuf + len, 1, SZ - len, stdin);
  in_left = 0;
}
// Writes the pending output to stdout and empties the buffer.
void flush() {
  fwrite(outbuf, 1, out_right, stdout);
  out_right = 0;
}
// Advances in_left past every byte <= ' ' (refilling first when the buffer
// runs low).
void skip_space() {
  if (in_left + offset > in_right) load();
  while (inbuf[in_left] <= ' ') in_left++;
}

// Reads the next non-space character.
void single_read(char& c) {
  if (in_left + offset > in_right) load();
  skip_space();
  c = inbuf[in_left++];
}
// Appends the next whitespace-delimited token to S; S is not cleared first.
void single_read(string& S) {
  skip_space();
  while (true) {
    if (in_left == in_right) load();
    int i = in_left;
    for (; i != in_right; i++) {
      if (inbuf[i] <= ' ') break;
    }
    copy(inbuf + in_left, inbuf + i, back_inserter(S));
    in_left = i;
    // Stopped before the end of the buffer: the token is complete.
    if (i != in_right) break;
  }
}
// Reads a decimal integer (optional leading '-' for signed T) into x.
// Digits are accumulated in the unsigned magnitude type, so the most
// negative value of a signed T is parsed without signed overflow.
template <typename T,
          enable_if_t<internal::is_broadly_integral_v<T>>* = nullptr>
void single_read(T& x) {
  if (in_left + offset > in_right) load();
  skip_space();
  char c = inbuf[in_left++];
  [[maybe_unused]] bool minus = false;
  if constexpr (internal::is_broadly_signed_v<T>) {
    if (c == '-') minus = true, c = inbuf[in_left++];
  }
  using U = magnitude_t<T>;
  U mag = 0;
  while (c >= '0') {
    // c & 15 is the digit value of an ASCII digit.
    mag = mag * 10 + (c & 15);
    c = inbuf[in_left++];
  }
  x = T(mag);
  if constexpr (internal::is_broadly_signed_v<T>) {
    // Negate in the unsigned domain, then convert back.
    if (minus) x = T(U(0) - mag);
  }
}
// Reads each argument in order with the matching single_read overload.
void rd() {}
template <typename Head, typename... Tail>
void rd(Head& head, Tail&... tail) {
  single_read(head);
  rd(tail...);
}

// Appends one character.
void single_write(const char& c) {
  if (out_right > SZ - offset) flush();
  outbuf[out_right++] = c;
}
// Appends '1' for true and '0' for false.
void single_write(const bool& b) {
  if (out_right > SZ - offset) flush();
  outbuf[out_right++] = b ? '1' : '0';
}
// Strings bypass the buffer: pending output is flushed first so ordering is
// kept, then the string is written directly.
void single_write(const string& S) {
  flush(), fwrite(S.data(), 1, S.size(), stdout);
}
void single_write(const char* p) { flush(), fwrite(p, 1, strlen(p), stdout); }
// Appends the decimal text of an integer. The absolute value is taken in
// the unsigned magnitude type, so the most negative value of a signed T is
// printed correctly instead of overflowing on negation.
template <typename T,
          enable_if_t<internal::is_broadly_integral_v<T>>* = nullptr>
void single_write(const T& _x) {
  if (out_right > SZ - offset) flush();
  if (_x == 0) {
    outbuf[out_right++] = '0';
    return;
  }
  using U = magnitude_t<T>;
  U x = U(_x);
  if constexpr (internal::is_broadly_signed_v<T>) {
    if (_x < 0) outbuf[out_right++] = '-', x = U(0) - x;
  }
  constexpr int buffer_size = sizeof(T) * 10 / 4;
  char buf[buffer_size];
  int i = buffer_size;
  // Peel off the low digits four at a time into the tail of buf.
  while (x >= 10000) {
    i -= 4;
    memcpy(buf + i, pre.num + (x % 10000) * 4, 4);
    x /= 10000;
  }
  // The leading 1-4 digits go straight to the output without zero padding.
  if (x < 100) {
    if (x < 10) {
      outbuf[out_right] = '0' + x;
      ++out_right;
    } else {
      // q = x / 10 for 10 <= x < 100 via multiply and shift.
      uint32_t q = (uint32_t(x) * 205) >> 11;
      uint32_t r = uint32_t(x) - q * 10;
      outbuf[out_right] = '0' + q;
      outbuf[out_right + 1] = '0' + r;
      out_right += 2;
    }
  } else {
    if (x < 1000) {
      memcpy(outbuf + out_right, pre.num + (x << 2) + 1, 3);
      out_right += 3;
    } else {
      memcpy(outbuf + out_right, pre.num + (x << 2), 4);
      out_right += 4;
    }
  }
  memcpy(outbuf + out_right, buf + i, buffer_size - i);
  out_right += buffer_size - i;
}
// Writes each argument in order with the matching single_write overload.
void wt() {}
template <typename Head, typename... Tail>
void wt(const Head& head, const Tail&... tail) {
  single_write(head);
  wt(tail...);
}
// Like wt(...), followed by a newline.
template <typename... Args>
void wtn(const Args&... x) {
  wt(x...);
  wt('\n');
}

// Registers flush() to run at program exit so buffered output is not lost.
struct Dummy {
  Dummy() { atexit(flush); }
} dummy;

}  // namespace fastio
using fastio::rd;
using fastio::skip_space;
using fastio::wt;
using fastio::wtn;
#line 20 "verify/verify-yosupo-math/yosupo-matrix-product-vectorize-modint.test.cpp"
//
#line 2 "modint/montgomery-modint.hpp"

// Modular integer for an odd compile-time modulus below 2^30, stored in
// Montgomery form. The stored value `a` is reduced lazily: arithmetic keeps
// it below 2 * mod, and comparison / get() fold it into [0, mod).
template <uint32_t mod>
struct LazyMontgomeryModInt {
  using mint = LazyMontgomeryModInt;
  using i32 = int32_t;
  using u32 = uint32_t;
  using u64 = uint64_t;

  // Inverse of mod modulo 2^32, found by Newton iteration.
  static constexpr u32 get_r() {
    u32 inv = mod;
    i32 rounds = 4;
    while (rounds-- > 0) inv *= 2 - mod * inv;
    return inv;
  }

  static constexpr u32 r = get_r();
  // 2^64 modulo mod; multiplying by it converts a value to Montgomery form.
  static constexpr u32 n2 = -u64(mod) % mod;
  static_assert(mod < (1 << 30), "invalid, mod >= 2 ^ 30");
  static_assert((mod & 1) == 1, "invalid, mod % 2 == 0");
  static_assert(r * mod == 1, "this code has bugs.");

  u32 a;

  constexpr LazyMontgomeryModInt() : a(0) {}
  // Accepts any signed 64-bit value; b % mod + mod is always positive.
  constexpr LazyMontgomeryModInt(const int64_t &b)
      : a(reduce(u64(b % mod + mod) * n2)) {}

  // Montgomery reduction of a 64-bit value.
  static constexpr u32 reduce(const u64 &b) {
    const u32 m = u32(b) * u32(-r);
    return (b + u64(m) * mod) >> 32;
  }

  constexpr mint &operator+=(const mint &b) {
    a += b.a - 2 * mod;
    if (i32(a) < 0) a += 2 * mod;
    return *this;
  }

  constexpr mint &operator-=(const mint &b) {
    a -= b.a;
    if (i32(a) < 0) a += 2 * mod;
    return *this;
  }

  constexpr mint &operator*=(const mint &b) {
    a = reduce(u64(a) * b.a);
    return *this;
  }

  constexpr mint &operator/=(const mint &b) { return *this *= b.inverse(); }

  constexpr mint operator+(const mint &b) const {
    mint res(*this);
    res += b;
    return res;
  }
  constexpr mint operator-(const mint &b) const {
    mint res(*this);
    res -= b;
    return res;
  }
  constexpr mint operator*(const mint &b) const {
    mint res(*this);
    res *= b;
    return res;
  }
  constexpr mint operator/(const mint &b) const {
    mint res(*this);
    res /= b;
    return res;
  }
  // Equality compares the representations folded into [0, mod).
  constexpr bool operator==(const mint &b) const {
    const u32 lhs = a >= mod ? a - mod : a;
    const u32 rhs = b.a >= mod ? b.a - mod : b.a;
    return lhs == rhs;
  }
  constexpr bool operator!=(const mint &b) const { return !(*this == b); }
  constexpr mint operator-() const { return mint() - *this; }
  constexpr mint operator+() const { return *this; }

  // Binary exponentiation: returns (*this)^n.
  constexpr mint pow(u64 n) const {
    mint acc(1), base(*this);
    for (; n > 0; n >>= 1) {
      if (n & 1) acc *= base;
      base *= base;
    }
    return acc;
  }

  // Multiplicative inverse via the extended Euclidean algorithm on
  // (get(), mod); the coefficient may be negative, which the constructor
  // normalizes.
  constexpr mint inverse() const {
    int cur = get(), nxt = mod, coef_cur = 1, coef_nxt = 0;
    while (nxt > 0) {
      const int q = cur / nxt;
      const int rem = cur - q * nxt;
      const int coef_rem = coef_cur - q * coef_nxt;
      cur = nxt;
      nxt = rem;
      coef_cur = coef_nxt;
      coef_nxt = coef_rem;
    }
    return mint{coef_cur};
  }

  friend ostream &operator<<(ostream &os, const mint &b) {
    return os << b.get();
  }

  friend istream &operator>>(istream &is, mint &b) {
    int64_t raw;
    is >> raw;
    b = mint(raw);
    return is;
  }

  // Ordinary (non-Montgomery) value in [0, mod).
  constexpr u32 get() const {
    const u32 v = reduce(a);
    return v >= mod ? v - mod : v;
  }

  static constexpr u32 get_mod() { return mod; }
};
#line 22 "verify/verify-yosupo-math/yosupo-matrix-product-vectorize-modint.test.cpp"
//
#line 2 "math-fast/mat-prod-strassen.hpp"

#line 2 "modint/vectorize-modint.hpp"

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")

#line 8 "modint/vectorize-modint.hpp"
using namespace std;

// 256-bit integer register type that backs mmint.
using m256 = __m256i;

// Eight 32-bit modular integers packed into one AVX2 register, in the same
// lazy Montgomery representation as the scalar modint: operator+ and
// operator- keep lanes below 2 * mod. set_mod<mint>() must be called before
// any arithmetic so the broadcast constants below are populated.
struct alignas(32) mmint {
  m256 x;
  // Broadcast constants set by set_mod(): R = mint::r, M0 = 0, M1 = mod,
  // M2 = 2 * mod, N2 = mint::n2.
  static mmint R, M0, M1, M2, N2;

  mmint() : x() {}
  inline mmint(const m256& _x) : x(_x) {}
  // Broadcasts a to all eight lanes.
  inline mmint(unsigned int a) : x(_mm256_set1_epi32(a)) {}
  // Lane i receives a_i.
  inline mmint(unsigned int a0, unsigned int a1, unsigned int a2,
               unsigned int a3, unsigned int a4, unsigned int a5,
               unsigned int a6, unsigned int a7)
      : x(_mm256_set_epi32(a7, a6, a5, a4, a3, a2, a1, a0)) {}
  inline operator m256&() { return x; }
  inline operator const m256&() const { return x; }
  // Raw access to lane i (no modular conversion).
  inline int& operator[](int i) { return *(reinterpret_cast<int*>(&x) + i); }
  inline const int& operator[](int i) const {
    return *(reinterpret_cast<const int*>(&x) + i);
  }

  // Prints the eight lanes, converted out of Montgomery form, separated by
  // single spaces.
  friend ostream& operator<<(ostream& os, const mmint& m) {
    unsigned r = R[0], mod = M1[0];
    auto reduce1 = [&](const uint64_t& b) {
      unsigned res = (b + uint64_t(unsigned(b) * unsigned(-r)) * mod) >> 32;
      return res >= mod ? res - mod : res;
    };
    for (int i = 0; i < 8; i++) {
      os << reduce1(m[i]) << (i == 7 ? "" : " ");
    }
    return os;
  }

  // Loads the constants of the scalar modint type `mint` (its r, n2 and
  // get_mod()) into the broadcast registers.
  template <typename mint>
  static void set_mod() {
    R = _mm256_set1_epi32(mint::r);
    M0 = _mm256_setzero_si256();
    M1 = _mm256_set1_epi32(mint::get_mod());
    M2 = _mm256_set1_epi32(mint::get_mod() * 2);
    N2 = _mm256_set1_epi32(mint::n2);
  }

  // Montgomery-reduces eight 64-bit products. prod02 holds the products of
  // the even 32-bit lanes and prod13 those of the odd lanes (as produced in
  // operator*). Each result lane is hi + mod - hi((lo * R) * mod).
  static inline mmint reduce(const mmint& prod02, const mmint& prod13) {
    m256 unpalo = _mm256_unpacklo_epi32(prod02, prod13);
    m256 unpahi = _mm256_unpackhi_epi32(prod02, prod13);
    m256 prodlo = _mm256_unpacklo_epi64(unpalo, unpahi);
    m256 prodhi = _mm256_unpackhi_epi64(unpalo, unpahi);
    m256 hiplm1 = _mm256_add_epi32(prodhi, M1);
    m256 prodlohi = _mm256_shuffle_epi32(prodlo, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(prodlo, R);
    m256 lmlr13 = _mm256_mul_epu32(prodlohi, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    return _mm256_sub_epi32(hiplm1, prod);
  }

  // Converts plain integers to Montgomery form (multiplies by N2).
  static inline mmint itom(const mmint& A) { return A * N2; }

  // Converts Montgomery form back to plain integers; each lane of the
  // result is either 0 or mod minus a value below mod.
  static inline mmint mtoi(const mmint& A) {
    m256 A13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 lmlr02 = _mm256_mul_epu32(A, R);
    m256 lmlr13 = _mm256_mul_epu32(A13, R);
    m256 prod02_ = _mm256_mul_epu32(lmlr02, M1);
    m256 prod13_ = _mm256_mul_epu32(lmlr13, M1);
    m256 unpalo_ = _mm256_unpacklo_epi32(prod02_, prod13_);
    m256 unpahi_ = _mm256_unpackhi_epi32(prod02_, prod13_);
    m256 prod = _mm256_unpackhi_epi64(unpalo_, unpahi_);
    m256 cmp = _mm256_cmpgt_epi32(prod, M0);
    m256 dif = _mm256_and_si256(cmp, M1);
    return _mm256_sub_epi32(dif, prod);
  }

  // Lane-wise A + B - 2 * mod, with 2 * mod added back where that went
  // negative.
  friend inline mmint operator+(const mmint& A, const mmint& B) {
    m256 apb = _mm256_add_epi32(A, B);
    m256 ret = _mm256_sub_epi32(apb, M2);
    m256 cmp = _mm256_cmpgt_epi32(M0, ret);
    m256 add = _mm256_and_si256(cmp, M2);
    return _mm256_add_epi32(add, ret);
  }

  // Lane-wise A - B, with 2 * mod added where the difference is negative.
  friend inline mmint operator-(const mmint& A, const mmint& B) {
    m256 ret = _mm256_sub_epi32(A, B);
    m256 cmp = _mm256_cmpgt_epi32(M0, ret);
    m256 add = _mm256_and_si256(cmp, M2);
    return _mm256_add_epi32(add, ret);
  }

  // Lane-wise Montgomery product.
  friend inline mmint operator*(const mmint& A, const mmint& B) {
    // Move the odd lanes into even positions so _mm256_mul_epu32 sees them.
    m256 a13 = _mm256_shuffle_epi32(A, 0xF5);
    m256 b13 = _mm256_shuffle_epi32(B, 0xF5);
    m256 prod02 = _mm256_mul_epu32(A, B);
    m256 prod13 = _mm256_mul_epu32(a13, b13);
    return reduce(prod02, prod13);
  }

  inline mmint& operator+=(const mmint& A) { return (*this) = (*this) + A; }
  inline mmint& operator-=(const mmint& A) { return (*this) = (*this) - A; }
  inline mmint& operator*=(const mmint& A) { return (*this) = (*this) * A; }

  // Folds every lane from [0, 2 * mod) into [0, mod). For v < mod the
  // difference v - mod wraps to a large unsigned value, so the unsigned
  // minimum keeps v; otherwise it picks v - mod.
  static inline m256 canonical(const m256& v) {
    return _mm256_min_epu32(v, _mm256_sub_epi32(v, M1));
  }

  // True when all eight lanes hold the same residue. Lanes are folded to
  // their canonical residue first, so a lazily reduced value and its fully
  // reduced counterpart compare equal, matching the scalar modint's
  // operator==.
  bool operator==(const mmint& A) const {
    m256 sub = _mm256_sub_epi32(canonical(x), canonical(A.x));
    return _mm256_testz_si256(sub, sub) == 1;
  }
  bool operator!=(const mmint& A) const { return !((*this) == A); }
};
__attribute__((aligned(32))) mmint mmint::R;
__attribute__((aligned(32))) mmint mmint::M0, mmint::M1, mmint::M2, mmint::N2;

/**
 * @brief vectorize modint
 */
#line 4 "math-fast/mat-prod-strassen.hpp"

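// Library for fast multiplication of B*B square matrices.
// The B*B matrices a and b are viewed as matrices with B rows and B/8
// columns (each entry being one mmint).
// s : stored in forward order, i.e. a_{i,k} is placed at s[i * (B / 8) + k].
// t : stored in reversed order, i.e. b_{k,j} is placed at t[j * B + k].
// u : stored in forward order, i.e. c_{i,j} is placed at u[i * (B / 8) + j].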
namespace fast_mat_prod_impl {
// Side length of the base-case block handled by mul_simd.
constexpr int B = 1 << 7;
// Number of mmint vectors per row of a base-case block.
constexpr int B8 = B / 8;

// Base-case kernel: multiplies two B x B blocks and stores the product in u.
//   s : left operand, B rows of B8 vectors each (s[i * B8 + k]).
//   t : right operand in the layout produced by place_t (t[j * B + k]).
//   u : result, B rows of B8 vectors each; every entry is overwritten.
// s and t are modified in place by the normalization loop below.
void mul_simd(mmint* __restrict__ s, mmint* __restrict__ t,
              mmint* __restrict__ u) {
  // Subtract mod from every lane that is greater than mod, so the operands
  // entering the multiply-accumulate are at most mod.
  for (int i = 0; i < B * B8; i++) {
    const m256 cmpS = _mm256_cmpgt_epi32(s[i], mmint::M1);
    const m256 cmpT = _mm256_cmpgt_epi32(t[i], mmint::M1);
    const m256 difS = _mm256_and_si256(cmpS, mmint::M1);
    const m256 difT = _mm256_and_si256(cmpT, mmint::M1);
    s[i] = _mm256_sub_epi32(s[i], difS);
    t[i] = _mm256_sub_epi32(t[i], difT);
  }

  // Only the odd 32-bit lanes are set, so viewed as 64-bit lanes th1 is
  // mod << 32 and th2 is (2 * mod) << 32.
  mmint th1, th2, zero = _mm256_setzero_si256();
  th1[1] = th1[3] = th1[5] = th1[7] = mmint::M1[0];
  th2[1] = th2[3] = th2[5] = th2[7] = mmint::M2[0];

// INIT_X(x, y): zero the two 64-bit accumulators (even and odd lanes) for
// output row i + x, output vector j + y.
#define INIT_X(x, y)                          \
  m256 prod02##x##y = _mm256_setzero_si256(); \
  m256 prod13##x##y = _mm256_setzero_si256()

// INIT_Y(j, k, l, y): load the t vector for inner index k + l and make a
// copy with the odd lanes moved into even positions.
#define INIT_Y(j, k, l, y)            \
  m256 T##y = t[(j + y) * B + k + l]; \
  const m256 T13##y = _mm256_shuffle_epi32(T##y, 0xF5);

// PROD(x, y): broadcast scalar (i + x, k + l) of s and add its 64-bit
// products with the loaded t vector to the accumulators.
#define PROD(x, y)                                              \
  m256 S##x##y = _mm256_set1_epi32(s[(i + x) * B8 + k / 8][l]); \
  const m256 ST02##x##y = _mm256_mul_epu32(S##x##y, T##y);      \
  const m256 ST13##x##y = _mm256_mul_epu32(S##x##y, T13##y);    \
  prod02##x##y = _mm256_add_epi64(prod02##x##y, ST02##x##y);    \
  prod13##x##y = _mm256_add_epi64(prod13##x##y, ST13##x##y)

// COMP(x, y): wherever an accumulator is negative as a signed 64-bit value,
// subtract th2 from it.
#define COMP(x, y)                                            \
  m256 cmp02##x##y = _mm256_cmpgt_epi64(zero, prod02##x##y);  \
  m256 cmp13##x##y = _mm256_cmpgt_epi64(zero, prod13##x##y);  \
  m256 dif02##x##y = _mm256_and_si256(cmp02##x##y, th2);      \
  m256 dif13##x##y = _mm256_and_si256(cmp13##x##y, th2);      \
  prod02##x##y = _mm256_sub_epi64(prod02##x##y, dif02##x##y); \
  prod13##x##y = _mm256_sub_epi64(prod13##x##y, dif13##x##y)

// REDUCE(x, y): twice subtract th1 from accumulators greater than th1, then
// Montgomery-reduce the pair and store the vector into u.
#define REDUCE(x, y)                                      \
  for (int _ = 0; _ < 2; _++) {                           \
    m256 cmp02 = _mm256_cmpgt_epi64(prod02##x##y, th1);   \
    m256 cmp13 = _mm256_cmpgt_epi64(prod13##x##y, th1);   \
    m256 dif02 = _mm256_and_si256(cmp02, th1);            \
    m256 dif13 = _mm256_and_si256(cmp13, th1);            \
    prod02##x##y = _mm256_sub_epi64(prod02##x##y, dif02); \
    prod13##x##y = _mm256_sub_epi64(prod13##x##y, dif13); \
  }                                                       \
  u[(i + x) * B8 + j + y] = mmint::reduce(prod02##x##y, prod13##x##y)

  // Each (i, j) iteration produces one output vector for eight consecutive
  // rows i .. i + 7.
  for (int i = 0; i < B; i += 8) {
    for (int j = 0; j < B8; j += 1) {
      INIT_X(0, 0);
      INIT_X(1, 0);
      INIT_X(2, 0);
      INIT_X(3, 0);
      INIT_X(4, 0);
      INIT_X(5, 0);
      INIT_X(6, 0);
      INIT_X(7, 0);
      for (int k = 0; k < B; k += 8) {
        for (int l = 0; l < 8; l++) {
          INIT_Y(j, k, l, 0);
          PROD(0, 0);
          PROD(1, 0);
          PROD(2, 0);
          PROD(3, 0);
          PROD(4, 0);
          PROD(5, 0);
          PROD(6, 0);
          PROD(7, 0);
        }
        // Pull the accumulators back after every eight products.
        COMP(0, 0);
        COMP(1, 0);
        COMP(2, 0);
        COMP(3, 0);
        COMP(4, 0);
        COMP(5, 0);
        COMP(6, 0);
        COMP(7, 0);
      }
      REDUCE(0, 0);
      REDUCE(1, 0);
      REDUCE(2, 0);
      REDUCE(3, 0);
      REDUCE(4, 0);
      REDUCE(5, 0);
      REDUCE(6, 0);
      REDUCE(7, 0);
    }
  }
}

// Drop every helper macro defined for mul_simd so none of them leaks into
// the rest of the translation unit.
#undef INIT_X
#undef INIT_Y
#undef PROD
#undef COMP
#undef REDUCE

void strassen(int N, mmint* __restrict__ s, mmint* __restrict__ t,
              mmint* __restrict__ u) {
  for (int i = 0; i < N * N / 8; i++) u[i] = mmint::M0;

  if (N == B) {
    mul_simd(s, t, u);
    return;
  }

  mmint* ps = s + N * N / 8;
  mmint* pt = t + N * N / 8;
  mmint* pu = u + N * N / 8;
  int nx = N * N / 32;
  int o11 = nx * 0, o12 = nx * 1, o21 = nx * 2, o22 = nx * 3;

  // P1
  for (int i = 0; i < nx; i++) ps[i] = s[o11 + i] + s[o22 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o11 + i] + t[o22 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o11 + i] = pu[i], u[o22 + i] = pu[i];

  // P2
  for (int i = 0; i < nx; i++) ps[i] = s[o21 + i] + s[o22 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o11 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o21 + i] = pu[i], u[o22 + i] -= pu[i];

  // P3
  for (int i = 0; i < nx; i++) ps[i] = s[o11 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o12 + i] - t[o22 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o12 + i] = pu[i], u[o22 + i] += pu[i];

  // P4
  for (int i = 0; i < nx; i++) ps[i] = s[o22 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o21 + i] - t[o11 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o11 + i] += pu[i], u[o21 + i] += pu[i];

  // P5
  for (int i = 0; i < nx; i++) ps[i] = s[o11 + i] + s[o12 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o22 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o11 + i] -= pu[i], u[o12 + i] += pu[i];

  // P6
  for (int i = 0; i < nx; i++) ps[i] = s[o21 + i] - s[o11 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o11 + i] + t[o12 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o22 + i] += pu[i];

  // P7
  for (int i = 0; i < nx; i++) ps[i] = s[o12 + i] - s[o22 + i];
  for (int i = 0; i < nx; i++) pt[i] = t[o21 + i] + t[o22 + i];
  strassen(N / 2, ps, pt, pu);
  for (int i = 0; i < nx; i++) u[o11 + i] += pu[i];
}

// Side length of the full matrices handled by prod().
constexpr int S = 1 << 10;
// Number of mmint vectors per full-matrix row.
constexpr int S8 = S >> 3;
// Work buffers for the left operand, right operand and result. Each holds
// one S x S matrix (S * S8 vectors) plus half as much again, which
// strassen() uses as scratch space behind the matrix.
mmint s[S * S8 * 3 / 2];
mmint t[S * S8 * 3 / 2];
mmint u[S * S8 * 3 / 2];

void place_s(int N, int a, int b, mmint* __restrict__ dst,
             mmint* __restrict__ src) {
  if (N == B) {
    for (int i = 0; i < B; i++) {
      memcpy(dst + i * B8, src + (a + i) * S8 + b / 8, B8 * sizeof(mmint));
    }
    return;
  }
  int nx = N * N / 32, M = N / 2;
  place_s(M, a + 0, b + 0, dst + nx * 0, src);
  place_s(M, a + 0, b + M, dst + nx * 1, src);
  place_s(M, a + M, b + 0, dst + nx * 2, src);
  place_s(M, a + M, b + M, dst + nx * 3, src);
}

void place_t(int N, int a, int b, mmint* __restrict__ dst,
             mmint* __restrict__ src) {
  if (N == B) {
    // t : 逆順に配置。すなわちb_{k,j}をt[j * B + k]に配置する。
    for (int k = 0; k < B; k++) {
      for (int j = 0; j < B8; j++) {
        dst[j * B + k] = src[(a + k) * S8 + j + b / 8];
      }
    }
    return;
  }
  int nx = N * N / 32, M = N / 2;
  place_t(M, a + 0, b + 0, dst + nx * 0, src);
  place_t(M, a + 0, b + M, dst + nx * 1, src);
  place_t(M, a + M, b + 0, dst + nx * 2, src);
  place_t(M, a + M, b + M, dst + nx * 3, src);
}

void place_rev(int N, int a, int b, mmint* __restrict__ dst,
               mmint* __restrict__ src) {
  if (N == B) {
    for (int i = 0; i < B; i++) {
      memcpy(src + (a + i) * S8 + b / 8, dst + i * B8, B8 * sizeof(mmint));
    }
    return;
  }
  int nx = N * N / 32, M = N / 2;
  place_rev(M, a + 0, b + 0, dst + nx * 0, src);
  place_rev(M, a + 0, b + M, dst + nx * 1, src);
  place_rev(M, a + M, b + 0, dst + nx * 2, src);
  place_rev(M, a + M, b + M, dst + nx * 3, src);
}

void prod(unsigned int* __restrict__ a, unsigned int* __restrict__ b,
          unsigned int* __restrict__ c) {
  place_s(S, 0, 0, s, reinterpret_cast<mmint*>(a));
  place_t(S, 0, 0, t, reinterpret_cast<mmint*>(b));
  for (int i = 0; i < S * S8; i++) s[i] = mmint::itom(s[i]);
  for (int i = 0; i < S * S8; i++) t[i] = mmint::itom(t[i]);
  strassen(S, s, t, u);
  for (int i = 0; i < S * S8; i++) u[i] = mmint::mtoi(u[i]);
  place_rev(S, 0, 0, u, reinterpret_cast<mmint*>(c));
}

}  // namespace fast_mat_prod_impl
#line 25 "verify/verify-yosupo-math/yosupo-matrix-product-vectorize-modint.test.cpp"

// Entry point for the "matrix_product" verification: reads an N x M matrix
// and an M x K matrix, multiplies them with fast_mat_prod_impl::prod and
// prints the N x K result, one row per line.
int main() {
  using mint = LazyMontgomeryModInt<998244353>;
  mmint::set_mod<mint>();

  using namespace fast_mat_prod_impl;

#ifdef PROFILER
  {
    // Benchmark mode: fill both operands with row + col and run the
    // full-size product 100 times without touching stdin/stdout.
    unsigned int* lhs = reinterpret_cast<unsigned int*>(t);
    unsigned int* rhs = reinterpret_cast<unsigned int*>(u);
    unsigned int* out = reinterpret_cast<unsigned int*>(s);
    for (int row = 0; row < S; row++) {
      unsigned int* lhs_row = lhs + row * S;
      unsigned int* rhs_row = rhs + row * S;
      for (int col = 0; col < S; col++) {
        const unsigned int v = row + col;
        lhs_row[col] = v;
        rhs_row[col] = v;
      }
    }
    int remaining = 100;
    while (remaining-- > 0) prod(lhs, rhs, out);
    return 0;
  }
#endif

  int N, M, K;
  rd(N, M, K);

  // The library's work arrays double as staging areas with a row stride of
  // S values: the left operand is staged in t, the right operand in u, and
  // prod() hands the result back through s.
  unsigned int* const lhs = reinterpret_cast<unsigned int*>(t);
  unsigned int* const rhs = reinterpret_cast<unsigned int*>(u);
  unsigned int* const out = reinterpret_cast<unsigned int*>(s);

  for (int i = 0; i < N; i++) {
    unsigned int* lhs_row = lhs + i * S;
    for (int k = 0; k < M; k++) rd(lhs_row[k]);
  }
  for (int k = 0; k < M; k++) {
    unsigned int* rhs_row = rhs + k * S;
    for (int j = 0; j < K; j++) rd(rhs_row[j]);
  }

  prod(lhs, rhs, out);

  for (int i = 0; i < N; i++) {
    const unsigned int* out_row = out + i * S;
    for (int j = 0; j < K; j++) {
      const unsigned int value = out_row[j];
      wt(value);
      // Space between entries, newline after the last entry of a row.
      wt(j + 1 == K ? '\n' : ' ');
    }
  }
}
Back to top page