Nyaan's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

⚠️ marathon/multi-armed-bandit.hpp

Depends on: misc/rng.hpp, internal/internal-seed.hpp

Code

#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>
using namespace std;

#include "../misc/rng.hpp"

// N-armed bandit that tries to maximize the total reward.
//
// Usage: call play() to obtain an arm index, then report the observed reward
// with reward(y) before calling play() again. The first thres = 5 * N plays
// cycle through the arms round-robin; afterwards an arm is sampled with
// probability proportional to e[i] = f(mean reward of arm i) = exp(mean_i / t).
// The temperature t is re-derived from the overall mean reward every thres
// plays.
struct MultiArmedBandit {
  MultiArmedBandit(int n)
      : N(n), last(-1), iter(0), thres(N * 5), num(N), v(N), e(N), t(1) {
    // With N == 0, play() and reward() would both evaluate `% 0`.
    assert(n > 0);
  }

  int N, last;            // number of arms; arm awaiting reward (-1: none)
  long long iter, thres;  // plays so far; warm-up length / re-tuning period
  vector<long long> num;  // num[i]: number of rewards recorded for arm i
  vector<double> v, e;    // v[i]: reward sum of arm i; e[i]: sampling weight
  double t;               // temperature used by f
  // Exponent applied to the mean reward when re-deriving t, and how much it
  // shrinks per re-tuning. These are per-instance: as function-local statics
  // they were shared (and advanced) by every bandit object.
  double u = 1.0, du = 0.01;

  // Returns the arm to pull next and remembers it until reward() is called.
  int play() {
    assert(last == -1);
    iter++;
    // Warm-up: round-robin so that every arm gets reward samples.
    if (iter <= thres) return last = iter % N;

    // Roulette-wheel selection with probability proportional to e[i].
    double s = accumulate(begin(e), end(e), 0.0);
    double x = rnd() * s;
    for (int i = 0; i < N; i++) {
      if ((x -= e[i]) <= 0) return last = i;
    }
    // Floating-point leftovers can keep x slightly positive after the loop.
    return last = N - 1;
  }

  // Weighting function: exp(x / t).
  double f(double x) { return exp(x / t); }

  // Records reward y for the arm returned by the preceding play().
  void reward(double y) {
    assert(last != -1);
    v[last] += y;
    num[last] += 1;
    e[last] = f(v[last] / num[last]);
    last = -1;

    // Re-derive t whenever iter % thres == 0.
    if (iter % thres == 0) {
      u = max(0.7, u - du);
      // Mean reward per play. play() and reward() strictly alternate (see the
      // asserts on `last`), so exactly iter rewards have been recorded here.
      // Dividing by thres instead inflated the mean k-fold at the k-th
      // re-tuning.
      double average = accumulate(begin(v), end(v), 0.0) / iter;
      // A zero mean would give t = 0 and exp(x / 0) = NaN/inf weights, so any
      // non-positive mean falls back to t = 1.
      t = average <= 0.0 ? 1.0 : pow(average, u);
      for (int i = 0; i < N; i++) e[i] = f(v[i] / num[i]);
    }
  }

  // Index of the arm with the largest weight e[i] (first one on ties).
  int best() { return max_element(begin(e), end(e)) - begin(e); }
};
#line 1 "marathon/multi-armed-bandit.hpp"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;

#line 2 "misc/rng.hpp"

#line 2 "internal/internal-seed.hpp"

#include <chrono>
using namespace std;

namespace internal {
// Seed derived from the current time: nanoseconds since the epoch of
// high_resolution_clock, XOR-ed with a constant and scrambled by three
// xorshift steps.
// All functions here are `inline` because this code lives in a header
// (internal/internal-seed.hpp); non-inline definitions cause multiple-
// definition link errors once the header is included from two or more
// translation units.
inline unsigned long long non_deterministic_seed() {
  unsigned long long m =
      chrono::duration_cast<chrono::nanoseconds>(
          chrono::high_resolution_clock::now().time_since_epoch())
          .count();
  m ^= 9845834732710364265uLL;
  m ^= m << 24, m ^= m >> 31, m ^= m << 35;
  return m;
}

// Fixed seed, used for reproducible local runs.
inline unsigned long long deterministic_seed() { return 88172645463325252ULL; }

// Returns a 64-bit seed value. The seed is fixed when NyaanLocal is defined
// (local runs) unless RANDOMIZED_SEED is also defined; otherwise it is taken
// from the clock.
// Note: calling this repeatedly in quick succession may return the same value
// many times (the clock may not advance between calls).
// `#define RANDOMIZED_SEED` to make the seed random even locally.
inline unsigned long long seed() {
#if defined(NyaanLocal) && !defined(RANDOMIZED_SEED)
  return deterministic_seed();
#else
  return non_deterministic_seed();
#endif
}

}  // namespace internal
#line 4 "misc/rng.hpp"

namespace my_rand {
using i64 = long long;
using u64 = unsigned long long;

// Raw 64-bit xorshift output. Each step is invertible, so a nonzero state
// never becomes 0: results lie in [1, 2^64 - 1] provided the seed is nonzero.
// Functions are `inline` so this header can be included from several
// translation units (they then also share a single generator state).
inline u64 rng() {
  static u64 _x = internal::seed();
  return _x ^= _x << 7, _x ^= _x >> 9;
}

// Random integer in [l, r] (computed as rng() % width, so it carries a slight
// modulo bias). The width is computed in unsigned arithmetic so that extreme
// bounds cannot overflow a signed integer; the full 64-bit range (width 2^64,
// which wraps to 0) returns a raw draw instead of dividing by zero.
inline i64 rng(i64 l, i64 r) {
  assert(l <= r);
  u64 width = u64(r) - u64(l) + 1;
  if (width == 0) return i64(rng());
  return i64(u64(l) + rng() % width);
}

// Random integer in [l, r); same unsigned-width technique as rng(l, r).
inline i64 randint(i64 l, i64 r) {
  assert(l < r);
  return i64(u64(l) + rng() % (u64(r) - u64(l)));
}

// Choose n distinct numbers from [l, r), returned in ascending order
// (Floyd's sampling; expected O(n log n) including the final sort).
inline vector<i64> randset(i64 l, i64 r, i64 n) {
  // n < 0 would make the loop below run (practically) forever.
  assert(l <= r && 0 <= n && n <= r - l);
  unordered_set<i64> s;
  for (i64 i = n; i; --i) {
    i64 m = randint(l, r + 1 - i);
    // Earlier steps only produced values <= r - i - 1, so r - i is fresh.
    if (s.find(m) != s.end()) m = r - i;
    s.insert(m);
  }
  vector<i64> ret(begin(s), end(s));
  sort(begin(ret), end(ret));
  return ret;
}

// Uniform double in [0.0, 1.0): the top 53 bits of rng() scaled by 2^-53.
// (Scaling the full 64-bit value by 2^-64 could round up to exactly 1.0,
// because u64 -> double conversion rounds values near 2^64 to 2^64.)
inline double rnd() { return (rng() >> 11) * (1.0 / 9007199254740992.0); }
// [l, r) (up to floating-point rounding of l + rnd() * (r - l))
inline double rnd(double l, double r) {
  assert(l < r);
  return l + rnd() * (r - l);
}

// Shuffles v in place (Fisher-Yates: v[i] swaps with a uniform j in [0, i]).
template <typename T>
void randshf(vector<T>& v) {
  int n = v.size();
  for (int i = 1; i < n; i++) swap(v[i], v[randint(0, i + 1)]);
}

}  // namespace my_rand

using my_rand::randint;
using my_rand::randset;
using my_rand::randshf;
using my_rand::rnd;
using my_rand::rng;
#line 9 "marathon/multi-armed-bandit.hpp"

// N-armed bandit that tries to maximize the total reward.
//
// Usage: call play() to obtain an arm index, then report the observed reward
// with reward(y) before calling play() again. The first thres = 5 * N plays
// cycle through the arms round-robin; afterwards an arm is sampled with
// probability proportional to e[i] = f(mean reward of arm i) = exp(mean_i / t).
// The temperature t is re-derived from the overall mean reward every thres
// plays.
struct MultiArmedBandit {
  MultiArmedBandit(int n)
      : N(n), last(-1), iter(0), thres(N * 5), num(N), v(N), e(N), t(1) {
    // With N == 0, play() and reward() would both evaluate `% 0`.
    assert(n > 0);
  }

  int N, last;            // number of arms; arm awaiting reward (-1: none)
  long long iter, thres;  // plays so far; warm-up length / re-tuning period
  vector<long long> num;  // num[i]: number of rewards recorded for arm i
  vector<double> v, e;    // v[i]: reward sum of arm i; e[i]: sampling weight
  double t;               // temperature used by f
  // Exponent applied to the mean reward when re-deriving t, and how much it
  // shrinks per re-tuning. These are per-instance: as function-local statics
  // they were shared (and advanced) by every bandit object.
  double u = 1.0, du = 0.01;

  // Returns the arm to pull next and remembers it until reward() is called.
  int play() {
    assert(last == -1);
    iter++;
    // Warm-up: round-robin so that every arm gets reward samples.
    if (iter <= thres) return last = iter % N;

    // Roulette-wheel selection with probability proportional to e[i].
    double s = accumulate(begin(e), end(e), 0.0);
    double x = rnd() * s;
    for (int i = 0; i < N; i++) {
      if ((x -= e[i]) <= 0) return last = i;
    }
    // Floating-point leftovers can keep x slightly positive after the loop.
    return last = N - 1;
  }

  // Weighting function: exp(x / t).
  double f(double x) { return exp(x / t); }

  // Records reward y for the arm returned by the preceding play().
  void reward(double y) {
    assert(last != -1);
    v[last] += y;
    num[last] += 1;
    e[last] = f(v[last] / num[last]);
    last = -1;

    // Re-derive t whenever iter % thres == 0.
    if (iter % thres == 0) {
      u = max(0.7, u - du);
      // Mean reward per play. play() and reward() strictly alternate (see the
      // asserts on `last`), so exactly iter rewards have been recorded here.
      // Dividing by thres instead inflated the mean k-fold at the k-th
      // re-tuning.
      double average = accumulate(begin(v), end(v), 0.0) / iter;
      // A zero mean would give t = 0 and exp(x / 0) = NaN/inf weights, so any
      // non-positive mean falls back to t = 1.
      t = average <= 0.0 ? 1.0 : pow(average, u);
      for (int i = 0; i < N; i++) e[i] = f(v[i] / num[i]);
    }
  }

  // Index of the arm with the largest weight e[i] (first one on ties).
  int best() { return max_element(begin(e), end(e)) - begin(e); }
};
Back to top page