Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:heavy_check_mark: Schönhage-Strassen Algorithm(radix-2)
(ntt/schoenhage-strassen-radix2.hpp)

Verified with

Code

#pragma once

template <typename T>
struct Schoenhage_Strassen_radix2 {
  T* buf = nullptr;

  void rot(T* dst, T* src, int s, int d) {
    assert(0 <= d and d < 2 * s);
    bool f = s <= d;
    if (s <= d) d -= s;
    int i = 0;
    if (f) {
      for (; i < s - d; i++) dst[i + d] = -src[i];
      for (; i < s; i++) dst[i + d - s] = src[i];
    } else {
      for (; i < s - d; i++) dst[i + d] = src[i];
      for (; i < s; i++) dst[i + d - s] = -src[i];
    }
  }

  void in_add(T* dst, T* src, int s) {
    for (int i = 0; i < s; i++) dst[i] += src[i];
  }
  void in_sub(T* dst, T* src, int s) {
    for (int i = 0; i < s; i++) dst[i] -= src[i];
  }

  void cp(T* dst, T* src, int s) { memcpy(dst, src, s * sizeof(T)); }
  void reset(T* dst, int s) { fill(dst, dst + s, T()); }

  // R[x] / (1 + x^(2^m)) 上の長さ2^LのFFT
  void fft(T* a, int l, int m) {
    if (l == 0) return;
    int L = 1 << l, M = 1 << m;
    assert(M * 2 >= L);
    assert(buf != nullptr);

    vector<int> dw(l - 1);
    for (int i = 0; i < l - 1; i++) {
      dw[i] = (1 << (l - 2 - i)) + (1 << (l - 1 - i)) - (1 << (l - 1));
      if (dw[i] < 0) dw[i] += L;
      if (L == M) dw[i] *= 2;
      if (2 * L == M) dw[i] *= 4;
    }

    for (int d = L; d >>= 1;) {
      int w = 0;
      for (int s = 0, k = 0;;) {
        for (int i = s, j = s + d; i < s + d; ++i, ++j) {
          T *ai = a + i * M, *aj = a + j * M;
          rot(buf, aj, M, w);
          cp(aj, ai, M);
          in_add(ai, buf, M);
          in_sub(aj, buf, M);
        }
        if ((s += 2 * d) >= L) break;
        w += dw[__builtin_ctz(++k)];
        if (w >= 2 * M) w -= 2 * M;
      }
    }
  }

  // R[x] / (1 + x^(2^m)) 上の長さ2^LのIFFT
  void ifft(T* a, int l, int m) {
    if (l == 0) return;
    int L = 1 << l, M = 1 << m;
    assert(M * 2 >= L);
    assert(buf != nullptr);

    vector<int> dw(l - 1);
    for (int i = 0; i < l - 1; i++) {
      dw[i] = (1 << (l - 2 - i)) + (1 << (l - 1 - i)) - (1 << (l - 1));
      if (dw[i] < 0) dw[i] += L;
      if (L == M) dw[i] *= 2;
      if (2 * L == M) dw[i] *= 4;
    }

    for (int d = 1; d < L; d *= 2) {
      int w = 0;
      for (int s = 0, k = 0;;) {
        for (int i = s, j = s + d; i < s + d; ++i, ++j) {
          T *ai = a + i * M, *aj = a + j * M;
          cp(buf, ai, M);
          in_add(ai, aj, M);
          in_sub(buf, aj, M);
          rot(aj, buf, M, w);
        }
        if ((s += 2 * d) >= L) break;
        w -= dw[__builtin_ctz(++k)];
        if (w < 0) w += 2 * M;
      }
    }
  }

  // a <- ab / (x^(2^n)+1)
  int naive_mul(T* a, T* b, int n) {
    int N = 1 << n;
    reset(buf, N);
    for (int i = 0; i < N; i++) {
      int j = 0;
      for (; j < N - i; j++) buf[i + j] += a[i] * b[j];
      for (; j < N; j++) buf[i + j - N] -= a[i] * b[j];
    }
    cp(a, buf, N);
    return 0;
  }

  // a <- ab / (x^(2^n)+1)
  int inplace_mul(T* a, T* b, int n) {
    if (n <= 5) {
      naive_mul(a, b, n);
      return 0;
    }

    int l = (n + 1) / 2;
    int m = n / 2;
    int L = 1 << l, M = 1 << m, N = 1 << n;
    int cnt = 0;

    // R[x] (x^(2^(m+1))-1) R[y] (y^(2^l)-1)
    vector<T> A(N * 2), B(N * 2);
    reset(buf + M, M);
    for (int i = 0, s = 0, ds = 2 * M / L; i < L; i++) {
      // y -> x^{2m/r} y
      cp(buf, a + i * M, M);
      rot(A.data() + i * M * 2, buf, 2 * M, s);
      cp(buf, b + i * M, M);
      rot(B.data() + i * M * 2, buf, 2 * M, s);
      s += ds;
      if (s >= 4 * M) s -= 4 * M;
    }

    fft(A.data(), l, m + 1);
    fft(B.data(), l, m + 1);
    for (int i = 0; i < L; i++) {
      cnt = inplace_mul(A.data() + i * M * 2, B.data() + i * M * 2, m + 1);
    }
    ifft(A.data(), l, m + 1);

    for (int i = 0, s = 0, ds = 2 * M / L; i < L; i++) {
      // y -> x^{-2m/r} y
      cp(buf, A.data() + i * M * 2, 2 * M);
      rot(A.data() + i * M * 2, buf, 2 * M, s);
      s -= ds;
      if (s < 0) s += 4 * M;
    }

    for (int i = 0; i < L; i++) {
      for (int j = 0; j < M; j++) {
        a[i * M + j] = A[i * M * 2 + j];
        if (i == 0) {
          a[i * M + j] -= A[(L - 1) * M * 2 + M + j];
        } else {
          a[i * M + j] += A[(i - 1) * M * 2 + M + j];
        }
      }
    }
    return cnt + l;
  }

  // a <- ab / (x^(2^n)-1)
  int inplace_mul2(T* a, T* b, int n) {
    if (n <= 5) {
      naive_mul(a, b, n);
      return 0;
    }

    int l = (n + 1) / 2;
    int m = n / 2;
    int L = 1 << l, M = 1 << m, N = 1 << n;
    int cnt = 0;

    // R[x] (x^(2^(m+1))-1) R[y] (y^(2^l)-1)
    vector<T> A(N * 2), B(N * 2);
    for (int i = 0; i < L; i++) {
      cp(A.data() + i * M * 2, a + i * M, M);
      cp(B.data() + i * M * 2, b + i * M, M);
    }
    fft(A.data(), l, m + 1);
    fft(B.data(), l, m + 1);
    for (int i = 0; i < L; i++) {
      cnt = inplace_mul(A.data() + i * M * 2, B.data() + i * M * 2, m + 1);
    }
    ifft(A.data(), l, m + 1);
    for (int i = 0; i < L; i++) {
      for (int j = 0; j < M; j++) {
        a[i * M + j] = A[i * M * 2 + j];
        a[i * M + j] += A[(i ? i - 1 : L - 1) * M * 2 + M + j];
      }
    }
    return cnt + l;
  }

  pair<vector<T>, int> multiply(const vector<T>& a, const vector<T>& b) {
    int L = a.size() + b.size() - 1;
    int M = 1, m = 0;
    while (M < L) M *= 2, m++;
    buf = new T[M];
    vector<T> s(M), t(M);
    for (int i = 0; i < (int)a.size(); i++) s[i] = a[i];
    for (int i = 0; i < (int)b.size(); i++) t[i] = b[i];
    int cnt = inplace_mul2(s.data(), t.data(), m);
    vector<T> u(L);
    for (int i = 0; i < L; i++) u[i] = s[i];
    delete[] buf;
    return make_pair(u, cnt);
  }
};

/**
 * @brief Schönhage-Strassen Algorithm(radix-2)
 */
#line 2 "ntt/schoenhage-strassen-radix2.hpp"

template <typename T>
struct Schoenhage_Strassen_radix2 {
  T* buf = nullptr;

  void rot(T* dst, T* src, int s, int d) {
    assert(0 <= d and d < 2 * s);
    bool f = s <= d;
    if (s <= d) d -= s;
    int i = 0;
    if (f) {
      for (; i < s - d; i++) dst[i + d] = -src[i];
      for (; i < s; i++) dst[i + d - s] = src[i];
    } else {
      for (; i < s - d; i++) dst[i + d] = src[i];
      for (; i < s; i++) dst[i + d - s] = -src[i];
    }
  }

  void in_add(T* dst, T* src, int s) {
    for (int i = 0; i < s; i++) dst[i] += src[i];
  }
  void in_sub(T* dst, T* src, int s) {
    for (int i = 0; i < s; i++) dst[i] -= src[i];
  }

  void cp(T* dst, T* src, int s) { memcpy(dst, src, s * sizeof(T)); }
  void reset(T* dst, int s) { fill(dst, dst + s, T()); }

  // R[x] / (1 + x^(2^m)) 上の長さ2^LのFFT
  void fft(T* a, int l, int m) {
    if (l == 0) return;
    int L = 1 << l, M = 1 << m;
    assert(M * 2 >= L);
    assert(buf != nullptr);

    vector<int> dw(l - 1);
    for (int i = 0; i < l - 1; i++) {
      dw[i] = (1 << (l - 2 - i)) + (1 << (l - 1 - i)) - (1 << (l - 1));
      if (dw[i] < 0) dw[i] += L;
      if (L == M) dw[i] *= 2;
      if (2 * L == M) dw[i] *= 4;
    }

    for (int d = L; d >>= 1;) {
      int w = 0;
      for (int s = 0, k = 0;;) {
        for (int i = s, j = s + d; i < s + d; ++i, ++j) {
          T *ai = a + i * M, *aj = a + j * M;
          rot(buf, aj, M, w);
          cp(aj, ai, M);
          in_add(ai, buf, M);
          in_sub(aj, buf, M);
        }
        if ((s += 2 * d) >= L) break;
        w += dw[__builtin_ctz(++k)];
        if (w >= 2 * M) w -= 2 * M;
      }
    }
  }

  // R[x] / (1 + x^(2^m)) 上の長さ2^LのIFFT
  void ifft(T* a, int l, int m) {
    if (l == 0) return;
    int L = 1 << l, M = 1 << m;
    assert(M * 2 >= L);
    assert(buf != nullptr);

    vector<int> dw(l - 1);
    for (int i = 0; i < l - 1; i++) {
      dw[i] = (1 << (l - 2 - i)) + (1 << (l - 1 - i)) - (1 << (l - 1));
      if (dw[i] < 0) dw[i] += L;
      if (L == M) dw[i] *= 2;
      if (2 * L == M) dw[i] *= 4;
    }

    for (int d = 1; d < L; d *= 2) {
      int w = 0;
      for (int s = 0, k = 0;;) {
        for (int i = s, j = s + d; i < s + d; ++i, ++j) {
          T *ai = a + i * M, *aj = a + j * M;
          cp(buf, ai, M);
          in_add(ai, aj, M);
          in_sub(buf, aj, M);
          rot(aj, buf, M, w);
        }
        if ((s += 2 * d) >= L) break;
        w -= dw[__builtin_ctz(++k)];
        if (w < 0) w += 2 * M;
      }
    }
  }

  // a <- ab / (x^(2^n)+1)
  int naive_mul(T* a, T* b, int n) {
    int N = 1 << n;
    reset(buf, N);
    for (int i = 0; i < N; i++) {
      int j = 0;
      for (; j < N - i; j++) buf[i + j] += a[i] * b[j];
      for (; j < N; j++) buf[i + j - N] -= a[i] * b[j];
    }
    cp(a, buf, N);
    return 0;
  }

  // a <- ab / (x^(2^n)+1)
  int inplace_mul(T* a, T* b, int n) {
    if (n <= 5) {
      naive_mul(a, b, n);
      return 0;
    }

    int l = (n + 1) / 2;
    int m = n / 2;
    int L = 1 << l, M = 1 << m, N = 1 << n;
    int cnt = 0;

    // R[x] (x^(2^(m+1))-1) R[y] (y^(2^l)-1)
    vector<T> A(N * 2), B(N * 2);
    reset(buf + M, M);
    for (int i = 0, s = 0, ds = 2 * M / L; i < L; i++) {
      // y -> x^{2m/r} y
      cp(buf, a + i * M, M);
      rot(A.data() + i * M * 2, buf, 2 * M, s);
      cp(buf, b + i * M, M);
      rot(B.data() + i * M * 2, buf, 2 * M, s);
      s += ds;
      if (s >= 4 * M) s -= 4 * M;
    }

    fft(A.data(), l, m + 1);
    fft(B.data(), l, m + 1);
    for (int i = 0; i < L; i++) {
      cnt = inplace_mul(A.data() + i * M * 2, B.data() + i * M * 2, m + 1);
    }
    ifft(A.data(), l, m + 1);

    for (int i = 0, s = 0, ds = 2 * M / L; i < L; i++) {
      // y -> x^{-2m/r} y
      cp(buf, A.data() + i * M * 2, 2 * M);
      rot(A.data() + i * M * 2, buf, 2 * M, s);
      s -= ds;
      if (s < 0) s += 4 * M;
    }

    for (int i = 0; i < L; i++) {
      for (int j = 0; j < M; j++) {
        a[i * M + j] = A[i * M * 2 + j];
        if (i == 0) {
          a[i * M + j] -= A[(L - 1) * M * 2 + M + j];
        } else {
          a[i * M + j] += A[(i - 1) * M * 2 + M + j];
        }
      }
    }
    return cnt + l;
  }

  // a <- ab / (x^(2^n)-1)
  int inplace_mul2(T* a, T* b, int n) {
    if (n <= 5) {
      naive_mul(a, b, n);
      return 0;
    }

    int l = (n + 1) / 2;
    int m = n / 2;
    int L = 1 << l, M = 1 << m, N = 1 << n;
    int cnt = 0;

    // R[x] (x^(2^(m+1))-1) R[y] (y^(2^l)-1)
    vector<T> A(N * 2), B(N * 2);
    for (int i = 0; i < L; i++) {
      cp(A.data() + i * M * 2, a + i * M, M);
      cp(B.data() + i * M * 2, b + i * M, M);
    }
    fft(A.data(), l, m + 1);
    fft(B.data(), l, m + 1);
    for (int i = 0; i < L; i++) {
      cnt = inplace_mul(A.data() + i * M * 2, B.data() + i * M * 2, m + 1);
    }
    ifft(A.data(), l, m + 1);
    for (int i = 0; i < L; i++) {
      for (int j = 0; j < M; j++) {
        a[i * M + j] = A[i * M * 2 + j];
        a[i * M + j] += A[(i ? i - 1 : L - 1) * M * 2 + M + j];
      }
    }
    return cnt + l;
  }

  pair<vector<T>, int> multiply(const vector<T>& a, const vector<T>& b) {
    int L = a.size() + b.size() - 1;
    int M = 1, m = 0;
    while (M < L) M *= 2, m++;
    buf = new T[M];
    vector<T> s(M), t(M);
    for (int i = 0; i < (int)a.size(); i++) s[i] = a[i];
    for (int i = 0; i < (int)b.size(); i++) t[i] = b[i];
    int cnt = inplace_mul2(s.data(), t.data(), m);
    vector<T> u(L);
    for (int i = 0; i < L; i++) u[i] = s[i];
    delete[] buf;
    return make_pair(u, cnt);
  }
};

/**
 * @brief Schönhage-Strassen Algorithm(radix-2)
 */
Back to top page