Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:heavy_check_mark: modulo/gauss-elimination-fast.hpp

Depends on

Verified with

Code

#pragma once

#include "../modint/simd-montgomery.hpp"


namespace Gauss {
  
// Side length of the square scratch matrix below.
constexpr int MAT_SIZE = 4096;
// 64-byte aligned scratch matrix. GaussianElimination() views it as
// mint[MAT_SIZE][MAT_SIZE], reduces its input in place here, and leaves the
// reduced matrix behind for LinearEquation() to read.
alignas(64) uint32_t a_buf_[MAT_SIZE][MAT_SIZE];

// Gaussian elimination on an H x W matrix, vectorised with AVX2.
//
// The input is copied into the shared scratch matrix a_buf_ (viewed as mint)
// and reduced there in place; the reduced matrix stays in a_buf_ after the
// call. H and W must not exceed MAT_SIZE (not checked).
//
// mint must provide the raw member `a`, the static `r` and `get_mod()`, plus
// get(), inverse(), unary minus and operator*= (presumably a 4-byte
// Montgomery-form modint matching the montgomery_*_256 helpers -- confirm
// against the modint header).
//
// LinearEquation == 0: rows below each pivot are cleared (forward elimination
//   only).
// LinearEquation != 0: the last column is treated as a right-hand side and is
//   never used as a pivot column, every pivot row is scaled so its pivot is 1,
//   and the pivot column is cleared in all other rows.
//
// return value: (rank, det). det is the product of the pivots, negated once
// per row swap, and set to 0 when a scanned column has no pivot. An empty
// matrix yields (0, 1).
template <typename mint>
__attribute__((target("avx2"))) pair<int, mint> GaussianElimination(
    const vector<vector<mint>> &m, int LinearEquation = false) {
  mint det = 1;
  // Nothing to eliminate; also avoids reading m[0] of an empty matrix.
  if (m.empty()) return {0, det};

  mint(&a)[MAT_SIZE][MAT_SIZE] =
      *reinterpret_cast<mint(*)[MAT_SIZE][MAT_SIZE]>(a_buf_);
  int H = m.size(), W = m[0].size(), rank = 0;
  for (int i = 0; i < H; i++)
    for (int j = 0; j < W; j++) a[i][j].a = m[i][j].a;

  // Constants broadcast to all eight lanes for the SIMD helpers.
  __m256i r = _mm256_set1_epi32(mint::r);
  __m256i m0 = _mm256_set1_epi32(0);
  __m256i m1 = _mm256_set1_epi32(mint::get_mod());
  __m256i m2 = _mm256_set1_epi32(mint::get_mod() << 1);

  for (int j = 0; j < (LinearEquation ? (W - 1) : W); j++) {
    // find basis: first row at or below `rank` with a non-zero entry in column j
    if (rank == H) break;
    int idx = -1;
    for (int i = rank; i < H; i++) {
      if (a[i][j].get() != 0) {
        idx = i;
        break;
      }
    }
    if (idx == -1) {
      // No pivot in this column.
      det = 0;
      continue;
    }

    // swap the pivot row into position `rank`; each swap flips the sign
    if (rank != idx) {
      det = -det;
      for (int l = j; l < W; l++) swap(a[rank][l], a[idx][l]);
    }
    det *= a[rank][j];

    // normalize: scale the pivot row so that the pivot becomes 1
    if (LinearEquation) {
      if (a[rank][j].get() != 1) {
        mint coeff = a[rank][j].inverse();
        __m256i COEFF = _mm256_set1_epi32(coeff.a);
        // Start at the multiple of 8 at or below j so the aligned 256-bit
        // loads/stores stay aligned. The loop may run past W up to the next
        // multiple of 8, which is still inside the MAT_SIZE-wide row.
        for (int i = j / 8 * 8; i < W; i += 8) {
          __m256i R = _mm256_load_si256((__m256i *)(a[rank] + i));
          __m256i RmulC = montgomery_mul_256(R, COEFF, r, m1);
          _mm256_store_si256((__m256i *)(a[rank] + i), RmulC);
        }
      }
    }

    // elimination
    // The pivot is the same for every row of this column, so invert it once
    // here instead of once per eliminated row.
    mint pivot_inv = a[rank][j].inverse();
    for (int k = (LinearEquation ? 0 : rank + 1); k < H; k++) {
      if (k == rank) continue;
      if (a[k][j].get() != 0) {
        // row_k -= (a[k][j] / pivot) * row_rank
        mint coeff = a[k][j];
        coeff *= pivot_inv;
        __m256i COEFF = _mm256_set1_epi32(coeff.a);
        for (int i = j / 8 * 8; i < W; i += 8) {
          __m256i R = _mm256_load_si256((__m256i *)(a[rank] + i));
          __m256i K = _mm256_load_si256((__m256i *)(a[k] + i));
          __m256i RmulC = montgomery_mul_256(R, COEFF, r, m1);
          __m256i KmnsR = montgomery_sub_256(K, RmulC, m2, m0);
          _mm256_store_si256((__m256i *)(a[k] + i), KmnsR);
        }
      }
    }
    rank++;
  }
  return {rank, det};
}

// calculate determinant
template <typename mint>
mint determinant(const vector<vector<mint>> &mat) {
  return GaussianElimination(mat).second;
}

// Solves A x = B.
//   result[0]     ... one particular solution
//   result[1 ...] ... one basis vector of the solution space of A x = 0 per
//                     column that has no pivot
// An empty result means that no solution exists.
template <typename mint>
vector<vector<mint>> LinearEquation(vector<vector<mint>> A, vector<mint> B) {
  int H = A.size();
  int W = A[0].size();

  // Append B as an extra column so that it is reduced together with A.
  for (int row = 0; row < H; row++) A[row].push_back(B[row]);

  int rank = GaussianElimination(A, true).first;

  // The reduced augmented matrix is read back from the shared scratch buffer.
  mint(&a)[MAT_SIZE][MAT_SIZE] =
      *reinterpret_cast<mint(*)[MAT_SIZE][MAT_SIZE]>(a_buf_);

  // A non-zero right-hand side in a row at or past `rank` means no solution.
  for (int row = rank; row < H; ++row) {
    if (a[row][W] != 0) return vector<vector<mint>>{};
  }

  vector<vector<mint>> res(1, vector<mint>(W));
  vector<int> pivot(W, -1);

  // Locate the pivot column of each of the first `rank` rows; the column
  // cursor only moves forward from one row to the next.
  int col = 0;
  for (int row = 0; row < rank; ++row) {
    while (a[row][col] == 0) ++col;
    res[0][col] = a[row][W];
    pivot[col] = row;
  }

  // Every column without a pivot contributes one basis vector.
  for (int free_col = 0; free_col < W; ++free_col) {
    if (pivot[free_col] != -1) continue;
    vector<mint> x(W);
    x[free_col] = 1;
    for (int k = 0; k < free_col; ++k) {
      if (pivot[k] == -1) continue;
      x[k] = -a[pivot[k]][free_col];
    }
    res.push_back(x);
  }
  return res;
}

}  // namespace Gauss

using Gauss::determinant;
using Gauss::LinearEquation;
#line 2 "modulo/gauss-elimination-fast.hpp"

#line 2 "modint/simd-montgomery.hpp"

#include <immintrin.h>

// Lane-wise low 32 bits of the four 32-bit products a[i] * b[i].
__attribute__((target("sse4.2"))) inline __m128i my128_mullo_epu32(
    const __m128i &a, const __m128i &b) {
  const __m128i low_halves = _mm_mullo_epi32(a, b);
  return low_halves;
}

// Lane-wise high 32 bits of the four unsigned 64-bit products a[i] * b[i].
__attribute__((target("sse4.2"))) inline __m128i my128_mulhi_epu32(
    const __m128i &a, const __m128i &b) {
  // _mm_mul_epu32 multiplies lanes 0 and 2 only, so copy lanes 1 and 3 into
  // those positions for a second multiply.
  const __m128i a_odd = _mm_shuffle_epi32(a, 0xF5);
  const __m128i b_odd = _mm_shuffle_epi32(b, 0xF5);
  const __m128i even_products = _mm_mul_epu32(a, b);
  const __m128i odd_products = _mm_mul_epu32(a_odd, b_odd);
  // Interleave to (lo0, lo1, hi0, hi1) and (lo2, lo3, hi2, hi3), then keep the
  // upper 64 bits of each to obtain (hi0, hi1, hi2, hi3).
  const __m128i lanes01 = _mm_unpacklo_epi32(even_products, odd_products);
  const __m128i lanes23 = _mm_unpackhi_epi32(even_products, odd_products);
  return _mm_unpackhi_epi64(lanes01, lanes23);
}

// Lane-wise hi(a * b) + m1 - hi((lo(a * b) * r) * m1), where hi/lo are the
// upper/lower 32 bits of the 64-bit product. Presumably the Montgomery
// product with m1 the broadcast modulus and r the matching Montgomery
// constant -- confirm against the modint header.
__attribute__((target("sse4.2"))) inline __m128i montgomery_mul_128(
    const __m128i &a, const __m128i &b, const __m128i &r, const __m128i &m1) {
  const __m128i ab_hi = my128_mulhi_epu32(a, b);
  const __m128i ab_lo = my128_mullo_epu32(a, b);
  const __m128i q = my128_mullo_epu32(ab_lo, r);
  const __m128i qm_hi = my128_mulhi_epu32(q, m1);
  const __m128i biased = _mm_add_epi32(ab_hi, m1);
  return _mm_sub_epi32(biased, qm_hi);
}

// Lane-wise a + b - m2; every lane that compares below m0 as a signed 32-bit
// value then gets m2 added back.
__attribute__((target("sse4.2"))) inline __m128i montgomery_add_128(
    const __m128i &a, const __m128i &b, const __m128i &m2, const __m128i &m0) {
  const __m128i sum = _mm_add_epi32(a, b);
  const __m128i shifted = _mm_sub_epi32(sum, m2);
  const __m128i below = _mm_cmpgt_epi32(m0, shifted);  // all-ones where m0 > shifted
  const __m128i fixup = _mm_and_si128(below, m2);
  return _mm_add_epi32(fixup, shifted);
}

// Lane-wise a - b; every lane that compares below m0 as a signed 32-bit value
// then gets m2 added back.
__attribute__((target("sse4.2"))) inline __m128i montgomery_sub_128(
    const __m128i &a, const __m128i &b, const __m128i &m2, const __m128i &m0) {
  const __m128i diff = _mm_sub_epi32(a, b);
  const __m128i below = _mm_cmpgt_epi32(m0, diff);  // all-ones where m0 > diff
  const __m128i fixup = _mm_and_si128(below, m2);
  return _mm_add_epi32(fixup, diff);
}

// Lane-wise low 32 bits of the eight 32-bit products a[i] * b[i].
__attribute__((target("avx2"))) inline __m256i my256_mullo_epu32(
    const __m256i &a, const __m256i &b) {
  const __m256i low_halves = _mm256_mullo_epi32(a, b);
  return low_halves;
}

// Lane-wise high 32 bits of the eight unsigned 64-bit products a[i] * b[i].
__attribute__((target("avx2"))) inline __m256i my256_mulhi_epu32(
    const __m256i &a, const __m256i &b) {
  // _mm256_mul_epu32 multiplies the even 32-bit lanes only, so copy the odd
  // lanes into the even positions for a second multiply.
  const __m256i a_odd = _mm256_shuffle_epi32(a, 0xF5);
  const __m256i b_odd = _mm256_shuffle_epi32(b, 0xF5);
  const __m256i even_products = _mm256_mul_epu32(a, b);
  const __m256i odd_products = _mm256_mul_epu32(a_odd, b_odd);
  // Within each 128-bit half: interleave to (lo, lo, hi, hi) pairs, then keep
  // the upper 64 bits of each interleaved vector, i.e. the four high words.
  const __m256i lower_pairs = _mm256_unpacklo_epi32(even_products, odd_products);
  const __m256i upper_pairs = _mm256_unpackhi_epi32(even_products, odd_products);
  return _mm256_unpackhi_epi64(lower_pairs, upper_pairs);
}

// Lane-wise hi(a * b) + m1 - hi((lo(a * b) * r) * m1), where hi/lo are the
// upper/lower 32 bits of the 64-bit product. Presumably the Montgomery
// product with m1 the broadcast modulus and r the matching Montgomery
// constant -- confirm against the modint header.
__attribute__((target("avx2"))) inline __m256i montgomery_mul_256(
    const __m256i &a, const __m256i &b, const __m256i &r, const __m256i &m1) {
  const __m256i ab_hi = my256_mulhi_epu32(a, b);
  const __m256i ab_lo = my256_mullo_epu32(a, b);
  const __m256i q = my256_mullo_epu32(ab_lo, r);
  const __m256i qm_hi = my256_mulhi_epu32(q, m1);
  const __m256i biased = _mm256_add_epi32(ab_hi, m1);
  return _mm256_sub_epi32(biased, qm_hi);
}

// Lane-wise a + b - m2; every lane that compares below m0 as a signed 32-bit
// value then gets m2 added back.
__attribute__((target("avx2"))) inline __m256i montgomery_add_256(
    const __m256i &a, const __m256i &b, const __m256i &m2, const __m256i &m0) {
  const __m256i sum = _mm256_add_epi32(a, b);
  const __m256i shifted = _mm256_sub_epi32(sum, m2);
  const __m256i below = _mm256_cmpgt_epi32(m0, shifted);  // all-ones where m0 > shifted
  const __m256i fixup = _mm256_and_si256(below, m2);
  return _mm256_add_epi32(fixup, shifted);
}

// Lane-wise a - b; every lane that compares below m0 as a signed 32-bit value
// then gets m2 added back.
__attribute__((target("avx2"))) inline __m256i montgomery_sub_256(
    const __m256i &a, const __m256i &b, const __m256i &m2, const __m256i &m0) {
  const __m256i diff = _mm256_sub_epi32(a, b);
  const __m256i below = _mm256_cmpgt_epi32(m0, diff);  // all-ones where m0 > diff
  const __m256i fixup = _mm256_and_si256(below, m2);
  return _mm256_add_epi32(fixup, diff);
}
#line 4 "modulo/gauss-elimination-fast.hpp"


namespace Gauss {
  
// Side length of the square scratch matrix below.
constexpr int MAT_SIZE = 4096;
// 64-byte aligned scratch matrix. GaussianElimination() views it as
// mint[MAT_SIZE][MAT_SIZE], reduces its input in place here, and leaves the
// reduced matrix behind for LinearEquation() to read.
alignas(64) uint32_t a_buf_[MAT_SIZE][MAT_SIZE];

// Gaussian elimination on an H x W matrix, vectorised with AVX2.
//
// The input is copied into the shared scratch matrix a_buf_ (viewed as mint)
// and reduced there in place; the reduced matrix stays in a_buf_ after the
// call. H and W must not exceed MAT_SIZE (not checked).
//
// mint must provide the raw member `a`, the static `r` and `get_mod()`, plus
// get(), inverse(), unary minus and operator*= (presumably a 4-byte
// Montgomery-form modint matching the montgomery_*_256 helpers -- confirm
// against the modint header).
//
// LinearEquation == 0: rows below each pivot are cleared (forward elimination
//   only).
// LinearEquation != 0: the last column is treated as a right-hand side and is
//   never used as a pivot column, every pivot row is scaled so its pivot is 1,
//   and the pivot column is cleared in all other rows.
//
// return value: (rank, det). det is the product of the pivots, negated once
// per row swap, and set to 0 when a scanned column has no pivot. An empty
// matrix yields (0, 1).
template <typename mint>
__attribute__((target("avx2"))) pair<int, mint> GaussianElimination(
    const vector<vector<mint>> &m, int LinearEquation = false) {
  mint det = 1;
  // Nothing to eliminate; also avoids reading m[0] of an empty matrix.
  if (m.empty()) return {0, det};

  mint(&a)[MAT_SIZE][MAT_SIZE] =
      *reinterpret_cast<mint(*)[MAT_SIZE][MAT_SIZE]>(a_buf_);
  int H = m.size(), W = m[0].size(), rank = 0;
  for (int i = 0; i < H; i++)
    for (int j = 0; j < W; j++) a[i][j].a = m[i][j].a;

  // Constants broadcast to all eight lanes for the SIMD helpers.
  __m256i r = _mm256_set1_epi32(mint::r);
  __m256i m0 = _mm256_set1_epi32(0);
  __m256i m1 = _mm256_set1_epi32(mint::get_mod());
  __m256i m2 = _mm256_set1_epi32(mint::get_mod() << 1);

  for (int j = 0; j < (LinearEquation ? (W - 1) : W); j++) {
    // find basis: first row at or below `rank` with a non-zero entry in column j
    if (rank == H) break;
    int idx = -1;
    for (int i = rank; i < H; i++) {
      if (a[i][j].get() != 0) {
        idx = i;
        break;
      }
    }
    if (idx == -1) {
      // No pivot in this column.
      det = 0;
      continue;
    }

    // swap the pivot row into position `rank`; each swap flips the sign
    if (rank != idx) {
      det = -det;
      for (int l = j; l < W; l++) swap(a[rank][l], a[idx][l]);
    }
    det *= a[rank][j];

    // normalize: scale the pivot row so that the pivot becomes 1
    if (LinearEquation) {
      if (a[rank][j].get() != 1) {
        mint coeff = a[rank][j].inverse();
        __m256i COEFF = _mm256_set1_epi32(coeff.a);
        // Start at the multiple of 8 at or below j so the aligned 256-bit
        // loads/stores stay aligned. The loop may run past W up to the next
        // multiple of 8, which is still inside the MAT_SIZE-wide row.
        for (int i = j / 8 * 8; i < W; i += 8) {
          __m256i R = _mm256_load_si256((__m256i *)(a[rank] + i));
          __m256i RmulC = montgomery_mul_256(R, COEFF, r, m1);
          _mm256_store_si256((__m256i *)(a[rank] + i), RmulC);
        }
      }
    }

    // elimination
    // The pivot is the same for every row of this column, so invert it once
    // here instead of once per eliminated row.
    mint pivot_inv = a[rank][j].inverse();
    for (int k = (LinearEquation ? 0 : rank + 1); k < H; k++) {
      if (k == rank) continue;
      if (a[k][j].get() != 0) {
        // row_k -= (a[k][j] / pivot) * row_rank
        mint coeff = a[k][j];
        coeff *= pivot_inv;
        __m256i COEFF = _mm256_set1_epi32(coeff.a);
        for (int i = j / 8 * 8; i < W; i += 8) {
          __m256i R = _mm256_load_si256((__m256i *)(a[rank] + i));
          __m256i K = _mm256_load_si256((__m256i *)(a[k] + i));
          __m256i RmulC = montgomery_mul_256(R, COEFF, r, m1);
          __m256i KmnsR = montgomery_sub_256(K, RmulC, m2, m0);
          _mm256_store_si256((__m256i *)(a[k] + i), KmnsR);
        }
      }
    }
    rank++;
  }
  return {rank, det};
}

// calculate determinant
template <typename mint>
mint determinant(const vector<vector<mint>> &mat) {
  return GaussianElimination(mat).second;
}

// Solves A x = B.
//   result[0]     ... one particular solution
//   result[1 ...] ... one basis vector of the solution space of A x = 0 per
//                     column that has no pivot
// An empty result means that no solution exists.
template <typename mint>
vector<vector<mint>> LinearEquation(vector<vector<mint>> A, vector<mint> B) {
  int H = A.size();
  int W = A[0].size();

  // Append B as an extra column so that it is reduced together with A.
  for (int row = 0; row < H; row++) A[row].push_back(B[row]);

  int rank = GaussianElimination(A, true).first;

  // The reduced augmented matrix is read back from the shared scratch buffer.
  mint(&a)[MAT_SIZE][MAT_SIZE] =
      *reinterpret_cast<mint(*)[MAT_SIZE][MAT_SIZE]>(a_buf_);

  // A non-zero right-hand side in a row at or past `rank` means no solution.
  for (int row = rank; row < H; ++row) {
    if (a[row][W] != 0) return vector<vector<mint>>{};
  }

  vector<vector<mint>> res(1, vector<mint>(W));
  vector<int> pivot(W, -1);

  // Locate the pivot column of each of the first `rank` rows; the column
  // cursor only moves forward from one row to the next.
  int col = 0;
  for (int row = 0; row < rank; ++row) {
    while (a[row][col] == 0) ++col;
    res[0][col] = a[row][W];
    pivot[col] = row;
  }

  // Every column without a pivot contributes one basis vector.
  for (int free_col = 0; free_col < W; ++free_col) {
    if (pivot[free_col] != -1) continue;
    vector<mint> x(W);
    x[free_col] = 1;
    for (int k = 0; k < free_col; ++k) {
      if (pivot[k] == -1) continue;
      x[k] = -a[pivot[k]][free_col];
    }
    res.push_back(x);
  }
  return res;
}

}  // namespace Gauss

using Gauss::determinant;
using Gauss::LinearEquation;
Back to top page