marathon/multi-armed-bandit.hpp
Depends on: misc/rng.hpp (which itself depends on internal/internal-seed.hpp)
Code
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>
using namespace std;
#include "../misc/rng.hpp"
// Multi-armed bandit over N arms that tries to maximize the total reward.
//
// Usage: call play() to obtain the arm to pull, then report the reward that
// arm produced with reward(y). The two calls must strictly alternate
// (checked by assert).
//  - The first `thres` (= 5N) plays visit the arms round-robin.
//  - Afterwards arm i is drawn with probability proportional to
//    e[i] = exp(mean_i / t), where mean_i = v[i] / num[i].
//  - Every `thres` plays the temperature t is re-estimated as
//    (mean reward per play)^u, with u lowered by du each time (floor 0.7).
struct MultiArmedBandit {
  // n: number of arms; must be positive.
  MultiArmedBandit(int n)
      : N(n), last(-1), iter(0), thres(N * 5), num(N), v(N), e(N), t(1) {
    // With n == 0, thres == 0 and `iter % thres` would divide by zero.
    assert(n > 0);
  }
  int N;                       // number of arms
  int last;                    // arm handed out by play() awaiting reward(), or -1
  long long iter;              // number of play() calls so far
  long long thres;             // warm-up length and re-tuning period (5N plays)
  std::vector<long long> num;  // num[i]: number of rewards received by arm i
  std::vector<double> v;       // v[i]: sum of rewards received by arm i
  std::vector<double> e;       // e[i]: selection weight f(v[i] / num[i])
  double t;                    // softmax temperature
  // Temperature exponent and its decrement per re-tuning. Kept per instance
  // (rather than as function-local statics) so that independent bandits do
  // not share, and mutually advance, the same annealing state.
  double u = 1.0;
  double du = 0.01;

  // Chooses the arm to pull next and records it in `last`.
  int play() {
    assert(last == -1);
    iter++;
    if (iter <= thres) return last = iter % N;  // round-robin warm-up
    double s = std::accumulate(std::begin(e), std::end(e), 0.0);
    double x = rnd() * s;
    for (int i = 0; i < N; i++) {
      if ((x -= e[i]) <= 0) return last = i;
    }
    // Reached only when floating-point rounding leaves x slightly above 0.
    return last = N - 1;
  }

  // Weighting function: maps a mean reward to a selection weight.
  // NOTE(review): exp overflows to inf once x / t exceeds ~709 (e.g. very
  // large rewards); weights are recomputed at every re-tuning.
  double f(double x) const { return std::exp(x / t); }

  // Records reward y for the arm returned by the preceding play().
  void reward(double y) {
    assert(last != -1);
    v[last] += y;
    num[last] += 1;
    e[last] = f(v[last] / num[last]);
    last = -1;
    // Re-tune t every `thres` plays.
    if (iter % thres == 0) {
      u = std::max(0.7, u - du);
      // play()/reward() alternate, so exactly `iter` rewards have been
      // recorded: this is the mean reward per play over the whole history.
      double average = std::accumulate(std::begin(v), std::end(v), 0.0) / iter;
      // t must be strictly positive; t == 0 would make every weight inf/NaN.
      t = average <= 0.0 ? 1.0 : std::pow(average, u);
      for (int i = 0; i < N; i++) e[i] = f(v[i] / num[i]);
    }
  }

  // Index of the arm with the largest weight e[i] (first one on ties). Since
  // exp is increasing and t > 0, this is the arm with the highest observed
  // mean reward among the arms that have been rewarded.
  int best() const {
    return static_cast<int>(std::max_element(std::begin(e), std::end(e)) -
                            std::begin(e));
  }
};
#line 1 "marathon/multi-armed-bandit.hpp"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <unordered_set>
#include <vector>
using namespace std;
#line 2 "misc/rng.hpp"
#line 2 "internal/internal-seed.hpp"
#include <chrono>
using namespace std;
namespace internal {
// Seed derived from the current high_resolution_clock reading (in
// nanoseconds), mixed with an XOR constant and three xorshift steps.
// `inline` because this lives in a header: without it, including the header
// from several translation units causes multiple-definition link errors.
inline unsigned long long non_deterministic_seed() {
  unsigned long long m =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::high_resolution_clock::now().time_since_epoch())
          .count();
  m ^= 9845834732710364265uLL;
  m ^= m << 24, m ^= m >> 31, m ^= m << 35;
  return m;
}
// Fixed seed used for reproducible local runs.
inline unsigned long long deterministic_seed() { return 88172645463325252ULL; }
// Returns a 64-bit seed. The seed is fixed (deterministic_seed()) when
// NyaanLocal is defined and RANDOMIZED_SEED is not; otherwise it comes from
// the clock. Define RANDOMIZED_SEED to get a clock-based seed even locally.
// Beware: calling this repeatedly can return the same value many times.
// That always happens in the fixed case, and it can happen in the clock case
// if the clock has not advanced. Do not rely on it for several independent
// seeds.
inline unsigned long long seed() {
#if defined(NyaanLocal) && !defined(RANDOMIZED_SEED)
  return deterministic_seed();
#else
  return non_deterministic_seed();
#endif
}
}  // namespace internal
#line 4 "misc/rng.hpp"
namespace my_rand {
using i64 = long long;
using u64 = unsigned long long;

// Functions are `inline` because they are defined in a header; otherwise
// including it from several translation units breaks linking.

// Raw 64-bit xorshift generator (shifts 7 and 9), seeded once from
// internal::seed() on first use. Each xorshift step is invertible, so a
// nonzero state never becomes 0: outputs lie in [1, 2^64 - 1] provided the
// seed is nonzero.
inline u64 rng() {
  static u64 _x = internal::seed();
  return _x ^= _x << 7, _x ^= _x >> 9;
}

// Integer in the closed range [l, r] (modulo bias is not corrected).
// The width is computed in unsigned arithmetic so that wide ranges such as
// [LLONG_MIN, LLONG_MAX] do not overflow.
inline i64 rng(i64 l, i64 r) {
  assert(l <= r);
  const u64 width = u64(r) - u64(l) + 1;  // wraps to 0 only for the full range
  if (width == 0) return i64(rng());
  return i64(u64(l) + rng() % width);
}

// Integer in the half-open range [l, r) (modulo bias is not corrected).
inline i64 randint(i64 l, i64 r) {
  assert(l < r);
  return i64(u64(l) + rng() % (u64(r) - u64(l)));
}

// Chooses n distinct integers from [l, r) and returns them in ascending
// order, using Floyd's sampling algorithm.
inline std::vector<i64> randset(i64 l, i64 r, i64 n) {
  assert(l <= r && 0 <= n && u64(n) <= u64(r) - u64(l));
  std::unordered_set<i64> s;
  s.reserve(static_cast<std::size_t>(n));
  for (i64 i = n; i; --i) {
    // Candidate from [l, r - i]; spelled r - (i - 1) so r + 1 cannot overflow.
    i64 m = randint(l, r - (i - 1));
    if (s.find(m) != s.end()) m = r - i;
    s.insert(m);
  }
  std::vector<i64> ret(s.begin(), s.end());
  std::sort(ret.begin(), ret.end());
  return ret;
}

// Uniform double in [0.0, 1.0). Only the top 53 bits of rng() are used, so
// the conversion and the multiplication by 2^-53 are exact. Scaling the full
// 64-bit value by 2^-64 instead can round up to exactly 1.0.
inline double rnd() { return (rng() >> 11) * (1.0 / 9007199254740992.0); }

// Uniform double in [l, r).
// NOTE: floating-point rounding in l + rnd() * (r - l) can still yield r for
// some l, r.
inline double rnd(double l, double r) {
  assert(l < r);
  return l + rnd() * (r - l);
}

// Shuffles v in place (Fisher-Yates: element i is swapped with a uniformly
// chosen index in [0, i]).
template <typename T>
void randshf(std::vector<T>& v) {
  int n = static_cast<int>(v.size());
  for (int i = 1; i < n; i++) {
    using std::swap;  // keep ADL so proxy references (vector<bool>) also work
    swap(v[i], v[randint(0, i + 1)]);
  }
}
}  // namespace my_rand
using my_rand::randint;
using my_rand::randset;
using my_rand::randshf;
using my_rand::rnd;
using my_rand::rng;
#line 9 "marathon/multi-armed-bandit.hpp"
// Multi-armed bandit over N arms that tries to maximize the total reward.
//
// Usage: call play() to obtain the arm to pull, then report the reward that
// arm produced with reward(y). The two calls must strictly alternate
// (checked by assert).
//  - The first `thres` (= 5N) plays visit the arms round-robin.
//  - Afterwards arm i is drawn with probability proportional to
//    e[i] = exp(mean_i / t), where mean_i = v[i] / num[i].
//  - Every `thres` plays the temperature t is re-estimated as
//    (mean reward per play)^u, with u lowered by du each time (floor 0.7).
struct MultiArmedBandit {
  // n: number of arms; must be positive.
  MultiArmedBandit(int n)
      : N(n), last(-1), iter(0), thres(N * 5), num(N), v(N), e(N), t(1) {
    // With n == 0, thres == 0 and `iter % thres` would divide by zero.
    assert(n > 0);
  }
  int N;                       // number of arms
  int last;                    // arm handed out by play() awaiting reward(), or -1
  long long iter;              // number of play() calls so far
  long long thres;             // warm-up length and re-tuning period (5N plays)
  std::vector<long long> num;  // num[i]: number of rewards received by arm i
  std::vector<double> v;       // v[i]: sum of rewards received by arm i
  std::vector<double> e;       // e[i]: selection weight f(v[i] / num[i])
  double t;                    // softmax temperature
  // Temperature exponent and its decrement per re-tuning. Kept per instance
  // (rather than as function-local statics) so that independent bandits do
  // not share, and mutually advance, the same annealing state.
  double u = 1.0;
  double du = 0.01;

  // Chooses the arm to pull next and records it in `last`.
  int play() {
    assert(last == -1);
    iter++;
    if (iter <= thres) return last = iter % N;  // round-robin warm-up
    double s = std::accumulate(std::begin(e), std::end(e), 0.0);
    double x = rnd() * s;
    for (int i = 0; i < N; i++) {
      if ((x -= e[i]) <= 0) return last = i;
    }
    // Reached only when floating-point rounding leaves x slightly above 0.
    return last = N - 1;
  }

  // Weighting function: maps a mean reward to a selection weight.
  // NOTE(review): exp overflows to inf once x / t exceeds ~709 (e.g. very
  // large rewards); weights are recomputed at every re-tuning.
  double f(double x) const { return std::exp(x / t); }

  // Records reward y for the arm returned by the preceding play().
  void reward(double y) {
    assert(last != -1);
    v[last] += y;
    num[last] += 1;
    e[last] = f(v[last] / num[last]);
    last = -1;
    // Re-tune t every `thres` plays.
    if (iter % thres == 0) {
      u = std::max(0.7, u - du);
      // play()/reward() alternate, so exactly `iter` rewards have been
      // recorded: this is the mean reward per play over the whole history.
      double average = std::accumulate(std::begin(v), std::end(v), 0.0) / iter;
      // t must be strictly positive; t == 0 would make every weight inf/NaN.
      t = average <= 0.0 ? 1.0 : std::pow(average, u);
      for (int i = 0; i < N; i++) e[i] = f(v[i] / num[i]);
    }
  }

  // Index of the arm with the largest weight e[i] (first one on ties). Since
  // exp is increasing and t > 0, this is the arm with the highest observed
  // mean reward among the arms that have been rewarded.
  int best() const {
    return static_cast<int>(std::max_element(std::begin(e), std::end(e)) -
                            std::begin(e));
  }
};
Back to top page