Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

:heavy_check_mark: 乗法的関数のprefix sum
(multiplicative-function/sum-of-multiplicative-function.hpp)

乗法的関数の和

乗法的関数 $f(p)$ のprefix sum

\[S(N) = \sum_{i=1}^N f(i)\]

を $\mathrm{O}\left(\frac{N^{\frac{3}{4}}}{\log N}\right)$ で求めるアルゴリズムを実装したライブラリ。

ただし、素数 $p$ について $f(p)=g(p)$ を満たす多項式 $g(p)$ が存在するとする。

アルゴリズムの概要

前計算として

\[S_p(n) = \sum_{p \leq n | p : \mathrm{prime}} f(i)\]

を $n = k, \left\lfloor \frac{N}{k} \right\rfloor$ $(1\leq k\leq N)$ に対して列挙するアルゴリズムを以下に説明する。

$p$ が素数の時 $f(p)$ は多項式なので、 $p$ の次数ごとに分解すると $f(N)$ は

\[S_c(N) = \sum_{p\leq N | p : \mathrm{prime}} p^c\]

の線形和で表すことが出来る。 $S_c(\lfloor \frac{N}{k} \rfloor)$ は素数カウントのアルゴリズム (いわゆる Lucy DP) を一般的に拡張した方法で高速に求められる。(素数カウントの時に求めた $\pi(n)$ は $c=0$ の時の場合であると言える。)

とおくと $x$ が素数の時に

\[f(x,n) = f(x - 1, n) - f(x-1,\left\lfloor\frac{n}{x}\right\rfloor)x^c\]

が成り立つ。ここで $n \lt x$ のとき $f(x,n) = f(x-1,n)$ 、$x \leq n \lt x^2$ のとき$ f(x,n) = f(x-1,n) - x^c$ となることを利用して $g(x,n) = f(x, n) + S_c(\min(x,n))$ とおくと、 $x$ が素数の時に

\[g(x,n) = \begin{cases} g(x-1,n) & \mathrm{if}\ n \lt x^2 \\ g(x-1,n) - \lbrace g(x-1,\lfloor\frac{n}{x}\rfloor) - S_c(x-1) - 1 \rbrace x^c& \mathrm{otherwise} \end{cases}\]

となる。

\[S_c(\lfloor\sqrt{N}\rfloor,N)=g(\lfloor\sqrt{N}\rfloor,N)-1\]

であるから $h(x,n)=g(x,n)-1$ と補正すると

\[\begin{aligned} S_c(N) &= h(\lfloor\sqrt{N}\rfloor,N)\\ h(0,n) &= -1 + \sum_{0 \leq m \leq n}m^c \end{aligned}\\ h(x,n) = \begin{cases} h(x-1,n) & \mathrm{if}\ x\ \mathrm{is}\ \mathrm{not}\ \mathrm{prime}\ \cup\ n \lt x^2 \\ h(x-1,n) - \lbrace h(x-1,\lfloor\frac{n}{x}\rfloor) - S_c(x-1) \rbrace x^c& \mathrm{otherwise} \end{cases}\]

を得る。(なお、 $S_c(x-1)=h(x-1,x-1)$ である。)以上より、素数カウントと同様のアルゴリズムで DP を行うことで $\mathrm{O}\left(\frac{N^{\frac{3}{4}}}{\log N}\right)$ で前計算ができる。

以上に説明したアルゴリズムによって $S_p(\left\lfloor \frac{N}{k} \right\rfloor)$ を列挙することが出来た。$S_p(\left\lfloor\frac{N}{k}\right\rfloor)$ から $S(N)$ を求めるアルゴリズムには洲閣篩(Zhouge sieve) や min_25篩 などが知られているが、ここでは Black Algorithm を用いた解法を説明する。 参考文献

まず、以下の条件を満たす $1$ から $N$ の頂点ラベルがついた木を考える。

この木の上を DFS して訪れた頂点 $n$ に対して $f(n)$ を加算するという操作を行うと $S(N)$ は計算できるが $\mathrm{O}(N)$ かかってしまう。そこで一工夫して、訪れた頂点 $n$ の子 $c$ について $f(c)$ を計算することを考える。

今、葉でない木上の頂点 $n$ および $f(n)$ が分かっているとする。この時、子の頂点に書かれた数の集合 $T$ に対して $\sum_{c \in T}f(c)$ は以下に説明する方法で高速に計算することが出来る。

\[\sum_{i+1\leq j\leq l} f(np_{j})=f(n)\sum_{i+1\leq j\leq l} f(p_{j})=f(n)\left(S_p\left(\left\lfloor\frac{N}{n}\right\rfloor\right)-S_p\left(p_i\right)\right)\]

と $S_p$ を用いて高速に計算できる。

以上のアルゴリズムを用いれば、$N$ 頂点の木の葉でない頂点を適切な情報をもってDFSすることで高速に $S(N)$ を求めることが出来る。

DFS の計算量は葉でないノードの個数に一致する。この個数は $\mathrm{O}\left(\frac{N^{\frac{3}{4}}}{\log N}\right)$ らしい。(参考文献のリンク先に書いてあるが中国語なので読めていない…) 追記:葉でないノードの個数は $\mathrm{O}(N^{1 - \epsilon})$ が正しい。参考
ただし定数倍は極めて軽く、$N = 10^{11}$ で葉でないノードの個数はおよそ $8.8 \times 10^8$ 個で DFS が一般的な実行時間に十分収まる。
以上より $S(N)$ を実用上高速に求めることが出来た。

関連:yukicoder No.1322 Totient Bound $\pi(N)$ の列挙と木上の DFS を使うとかなり見通しよく解くことが出来る。提出

ライブラリの使い方は verify のコードを適宜参照のこと。

Depends on

Verified with

Code

#pragma once

#include "../prime/prime-enumerate.hpp"

// f(p, c) : f(p^c) の値を返す
template <typename T, T (*f)(long long, long long)>
struct mf_prefix_sum {
  using i64 = long long;

  i64 M, sq, s;
  vector<int> p;
  int ps;
  vector<T> buf;
  T ans;

  mf_prefix_sum(i64 m) : M(m) {
    assert(m < (1LL << 42));
    sq = sqrt(M);
    while (sq * sq > M) sq--;
    while ((sq + 1) * (sq + 1) <= M) sq++;

    if (M != 0) {
      i64 hls = md(M, sq);
      if (hls != 1 && md(M, hls - 1) == sq) hls--;
      s = hls + sq;

      p = prime_enumerate(sq);
      ps = p.size();
      ans = T{};
    }
  }

  // 素数の個数関数に関するテーブル
  vector<T> pi_table() {
    if (M == 0) return {};
    i64 hls = md(M, sq);
    if (hls != 1 && md(M, hls - 1) == sq) hls--;

    vector<i64> hl(hls);
    for (int i = 1; i < hls; i++) hl[i] = md(M, i) - 1;

    vector<int> hs(sq + 1);
    iota(begin(hs), end(hs), -1);

    int pi = 0;
    for (auto& x : p) {
      i64 x2 = i64(x) * x;
      i64 imax = min<i64>(hls, md(M, x2) + 1);
      for (i64 i = 1, ix = x; i < imax; ++i, ix += x) {
        hl[i] -= (ix < hls ? hl[ix] : hs[md(M, ix)]) - pi;
      }
      for (int n = sq; n >= x2; n--) hs[n] -= hs[md(n, x)] - pi;
      pi++;
    }

    vector<T> res;
    res.reserve(2 * sq + 10);
    for (auto& x : hl) res.push_back(x);
    for (int i = hs.size(); --i;) res.push_back(hs[i]);
    assert((int)res.size() == s);
    return res;
  }

  // 素数の prefix sum に関するテーブル
  vector<T> prime_sum_table() {
    if (M == 0) return {};
    i64 hls = md(M, sq);
    if (hls != 1 && md(M, hls - 1) == sq) hls--;

    vector<T> h(s);
    T inv2 = T{2}.inverse();
    for (int i = 1; i < hls; i++) {
      T x = md(M, i);
      h[i] = x * (x + 1) * inv2 - 1;
    }
    for (int i = 1; i <= sq; i++) {
      T x = i;
      h[s - i] = x * (x + 1) / 2 - 1;
    }

    for (auto& x : p) {
      T xt = x;
      T pi = h[s - x + 1];
      i64 x2 = i64(x) * x;
      i64 imax = min<i64>(hls, md(M, x2) + 1);
      i64 ix = x;
      for (i64 i = 1; i < imax; ++i, ix += x) {
        h[i] -= ((ix < hls ? h[ix] : h[s - md(M, ix)]) - pi) * xt;
      }
      for (int n = sq; n >= x2; n--) {
        h[s - n] -= (h[s - md(n, x)] - pi) * xt;
      }
    }

    assert((int)h.size() == s);
    return h;
  }

  void dfs(int i, int c, i64 prod, T cur) {
    ans += cur * f(p[i], c + 1);
    i64 lim = md(M, prod);
    if (lim >= 1LL * p[i] * p[i]) dfs(i, c + 1, p[i] * prod, cur);
    cur *= f(p[i], c);
    ans += cur * (buf[idx(lim)] - buf[idx(p[i])]);
    int j = i + 1;
    // M < 2**42 -> p_j < 2**21 -> (p_j)^3 < 2**63
    for (; j < ps && 1LL * p[j] * p[j] * p[j] <= lim; j++) {
      dfs(j, 1, prod * p[j], cur);
    }
    for (; j < ps && 1LL * p[j] * p[j] <= lim; j++) {
      T sm = f(p[j], 2);
      int id1 = idx(md(lim, p[j])), id2 = idx(p[j]);
      sm += f(p[j], 1) * (buf[id1] - buf[id2]);
      ans += cur * sm;
    }
  }

  // fprime 破壊的
  T run(vector<T>& fprime) {
    if (M == 0) return {};
    set_buf(fprime);
    assert((int)buf.size() == s);
    ans = buf[idx(M)] + 1;
    for (int i = 0; i < ps; i++) dfs(i, 1, p[i], 1);
    return ans;
  }

  i64 md(i64 n, i64 d) { return double(n) / d; }
  i64 idx(i64 n) { return n <= sq ? s - n : md(M, n); }
  void set_buf(vector<T>& _buf) { swap(buf, _buf); }
};

/**
 * @brief 乗法的関数のprefix sum
 * @docs docs/multiplicative-function/sum-of-multiplicative-function.md
 */
#line 2 "multiplicative-function/sum-of-multiplicative-function.hpp"

#line 2 "prime/prime-enumerate.hpp"

// Prime Sieve {2, 3, 5, 7, 11, 13, 17, ...}
vector<int> prime_enumerate(int N) {
  vector<bool> sieve(N / 3 + 1, 1);
  for (int p = 5, d = 4, i = 1, sqn = sqrt(N); p <= sqn; p += d = 6 - d, i++) {
    if (!sieve[i]) continue;
    for (int q = p * p / 3, r = d * p / 3 + (d * p % 3 == 2), s = 2 * p,
             qe = sieve.size();
         q < qe; q += r = s - r)
      sieve[q] = 0;
  }
  vector<int> ret{2, 3};
  for (int p = 5, d = 4, i = 1; p <= N; p += d = 6 - d, i++)
    if (sieve[i]) ret.push_back(p);
  while (!ret.empty() && ret.back() > N) ret.pop_back();
  return ret;
}
#line 4 "multiplicative-function/sum-of-multiplicative-function.hpp"

// f(p, c) : f(p^c) の値を返す
template <typename T, T (*f)(long long, long long)>
struct mf_prefix_sum {
  using i64 = long long;

  i64 M, sq, s;
  vector<int> p;
  int ps;
  vector<T> buf;
  T ans;

  mf_prefix_sum(i64 m) : M(m) {
    assert(m < (1LL << 42));
    sq = sqrt(M);
    while (sq * sq > M) sq--;
    while ((sq + 1) * (sq + 1) <= M) sq++;

    if (M != 0) {
      i64 hls = md(M, sq);
      if (hls != 1 && md(M, hls - 1) == sq) hls--;
      s = hls + sq;

      p = prime_enumerate(sq);
      ps = p.size();
      ans = T{};
    }
  }

  // 素数の個数関数に関するテーブル
  vector<T> pi_table() {
    if (M == 0) return {};
    i64 hls = md(M, sq);
    if (hls != 1 && md(M, hls - 1) == sq) hls--;

    vector<i64> hl(hls);
    for (int i = 1; i < hls; i++) hl[i] = md(M, i) - 1;

    vector<int> hs(sq + 1);
    iota(begin(hs), end(hs), -1);

    int pi = 0;
    for (auto& x : p) {
      i64 x2 = i64(x) * x;
      i64 imax = min<i64>(hls, md(M, x2) + 1);
      for (i64 i = 1, ix = x; i < imax; ++i, ix += x) {
        hl[i] -= (ix < hls ? hl[ix] : hs[md(M, ix)]) - pi;
      }
      for (int n = sq; n >= x2; n--) hs[n] -= hs[md(n, x)] - pi;
      pi++;
    }

    vector<T> res;
    res.reserve(2 * sq + 10);
    for (auto& x : hl) res.push_back(x);
    for (int i = hs.size(); --i;) res.push_back(hs[i]);
    assert((int)res.size() == s);
    return res;
  }

  // 素数の prefix sum に関するテーブル
  vector<T> prime_sum_table() {
    if (M == 0) return {};
    i64 hls = md(M, sq);
    if (hls != 1 && md(M, hls - 1) == sq) hls--;

    vector<T> h(s);
    T inv2 = T{2}.inverse();
    for (int i = 1; i < hls; i++) {
      T x = md(M, i);
      h[i] = x * (x + 1) * inv2 - 1;
    }
    for (int i = 1; i <= sq; i++) {
      T x = i;
      h[s - i] = x * (x + 1) / 2 - 1;
    }

    for (auto& x : p) {
      T xt = x;
      T pi = h[s - x + 1];
      i64 x2 = i64(x) * x;
      i64 imax = min<i64>(hls, md(M, x2) + 1);
      i64 ix = x;
      for (i64 i = 1; i < imax; ++i, ix += x) {
        h[i] -= ((ix < hls ? h[ix] : h[s - md(M, ix)]) - pi) * xt;
      }
      for (int n = sq; n >= x2; n--) {
        h[s - n] -= (h[s - md(n, x)] - pi) * xt;
      }
    }

    assert((int)h.size() == s);
    return h;
  }

  void dfs(int i, int c, i64 prod, T cur) {
    ans += cur * f(p[i], c + 1);
    i64 lim = md(M, prod);
    if (lim >= 1LL * p[i] * p[i]) dfs(i, c + 1, p[i] * prod, cur);
    cur *= f(p[i], c);
    ans += cur * (buf[idx(lim)] - buf[idx(p[i])]);
    int j = i + 1;
    // M < 2**42 -> p_j < 2**21 -> (p_j)^3 < 2**63
    for (; j < ps && 1LL * p[j] * p[j] * p[j] <= lim; j++) {
      dfs(j, 1, prod * p[j], cur);
    }
    for (; j < ps && 1LL * p[j] * p[j] <= lim; j++) {
      T sm = f(p[j], 2);
      int id1 = idx(md(lim, p[j])), id2 = idx(p[j]);
      sm += f(p[j], 1) * (buf[id1] - buf[id2]);
      ans += cur * sm;
    }
  }

  // fprime 破壊的
  T run(vector<T>& fprime) {
    if (M == 0) return {};
    set_buf(fprime);
    assert((int)buf.size() == s);
    ans = buf[idx(M)] + 1;
    for (int i = 0; i < ps; i++) dfs(i, 1, p[i], 1);
    return ans;
  }

  i64 md(i64 n, i64 d) { return double(n) / d; }
  i64 idx(i64 n) { return n <= sq ? s - n : md(M, n); }
  void set_buf(vector<T>& _buf) { swap(buf, _buf); }
};

/**
 * @brief 乗法的関数のprefix sum
 * @docs docs/multiplicative-function/sum-of-multiplicative-function.md
 */
Back to top page