#include "marathon/multi-armed-bandit.hpp"
#include <algorithm> #include <cassert> #include <cmath> #include <numeric> #include <vector> using namespace std; #include "../misc/rng.hpp" // N 択, 報酬最大化 struct MultiArmedBandit { MultiArmedBandit(int n) : N(n), last(-1), iter(0), thres(N * 5), num(N), v(N), e(N), t(1) {} int N, last; long long iter, thres; vector<long long> num; vector<double> v, e; double t; int play() { assert(last == -1); iter++; if (iter <= thres) return last = iter % N; double s = accumulate(begin(e), end(e), 0.0); double x = rnd() * s; for (int i = 0; i < N; i++) { if ((x -= e[i]) <= 0) return last = i; } return last = N - 1; } // 重み付け用の関数 double f(double x) { return exp(x / t); } void reward(double y) { assert(last != -1); v[last] += y; num[last] += 1; e[last] = f(v[last] / num[last]); last = -1; static double u = 1.0; static double du = 0.01; // iter % thres == 0 になったら t を再決定 if (iter % thres == 0) { u = max(0.7, u - du); double average = accumulate(begin(v), end(v), 0.0) / thres; t = average < 0.0 ? 1.0 : pow(average, u); for (int i = 0; i < N; i++) e[i] = f(v[i] / num[i]); } } int best() { return max_element(begin(e), end(e)) - begin(e); } };
#line 1 "marathon/multi-armed-bandit.hpp"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <unordered_set>
#include <vector>
using namespace std;

#line 2 "misc/rng.hpp"

#line 2 "internal/internal-seed.hpp"

#include <chrono>
using namespace std;

namespace internal {
// Seed derived from the current high_resolution_clock reading (nanoseconds),
// scrambled with a constant xor and a few shift-xor steps.
unsigned long long non_deterministic_seed() {
  unsigned long long m =
      chrono::duration_cast<chrono::nanoseconds>(
          chrono::high_resolution_clock::now().time_since_epoch())
          .count();
  m ^= 9845834732710364265uLL;
  m ^= m << 24, m ^= m >> 31, m ^= m << 35;
  return m;
}
// Fixed seed, identical on every run.
unsigned long long deterministic_seed() { return 88172645463325252UL; }
// Produces a 64-bit seed value.
// The seed is fixed when NyaanLocal is defined (local builds), unless
// RANDOMIZED_SEED is defined as well; otherwise it is clock based.
// Note: calling this several times in quick succession may return the same
// value repeatedly.
unsigned long long seed() {
#if defined(NyaanLocal) && !defined(RANDOMIZED_SEED)
  return deterministic_seed();
#else
  return non_deterministic_seed();
#endif
}
}  // namespace internal
#line 4 "misc/rng.hpp"

namespace my_rand {
using i64 = long long;
using u64 = unsigned long long;

// Next 64-bit pseudo-random value.
// The state is a function-local static seeded once from internal::seed() and
// advanced with a xorshift-style update (x ^= x << 7; x ^= x >> 9).
u64 rng() {
  static u64 _x = internal::seed();
  return _x ^= _x << 7, _x ^= _x >> 9;
}

// Integer in the closed range [l, r].
i64 rng(i64 l, i64 r) {
  assert(l <= r);
  return l + rng() % u64(r - l + 1);
}

// Integer in the half-open range [l, r).
i64 randint(i64 l, i64 r) {
  assert(l < r);
  return l + rng() % u64(r - l);
}

// Chooses n distinct numbers from [l, r) and returns them in ascending order.
// Requires 0 <= n <= r - l.
vector<i64> randset(i64 l, i64 r, i64 n) {
  // A negative n would make the countdown loop below run (practically)
  // forever, so it is rejected together with the other preconditions.
  assert(l <= r && 0 <= n && n <= r - l);
  unordered_set<i64> s;
  for (i64 i = n; i; --i) {
    // Draw from the prefix [l, r - i]; if the value is already taken, take
    // the last element of that prefix instead, which cannot be in s yet.
    i64 m = randint(l, r + 1 - i);
    if (s.find(m) != s.end()) m = r - i;
    s.insert(m);
  }
  vector<i64> ret;
  for (auto& x : s) ret.push_back(x);
  sort(begin(ret), end(ret));
  return ret;
}

// Floating point value, nominally in [0.0, 1.0): rng() scaled by about 2^-64.
double rnd() { return rng() * 5.42101086242752217004e-20; }
// Floating point value, nominally in [l, r).
double rnd(double l, double r) {
  assert(l < r);
  return l + rnd() * (r - l);
}

// Shuffles v in place: element i is swapped with a random position in [0, i].
template <typename T>
void randshf(vector<T>& v) {
  int n = v.size();
  for (int i = 1; i < n; i++) swap(v[i], v[randint(0, i + 1)]);
}

}  // namespace my_rand

using my_rand::randint;
using my_rand::randset;
using my_rand::randshf;
using my_rand::rnd;
using my_rand::rng;
#line 9 "marathon/multi-armed-bandit.hpp"

// N-armed bandit that tries to maximize the collected reward.
//
// Usage: alternate play() (returns the arm to try next) and reward(y)
// (reports the reward obtained from that arm). The alternation is enforced
// with assert.
//
// During the first `thres` (= 5 * N) calls of play() the arms are tried
// round-robin. Afterwards an arm is sampled with probability proportional to
// e[i] = exp((mean reward of arm i) / t).
//
// Precondition: n >= 1, because play() evaluates iter % N and reward()
// evaluates iter % thres.
struct MultiArmedBandit {
  MultiArmedBandit(int n)
      : N(n), last(-1), iter(0), thres(N * 5), num(N), v(N), e(N), t(1) {}

  int N, last;            // number of arms; arm awaiting its reward (-1: none)
  long long iter, thres;  // play() calls so far; warm-up length / update period
  vector<long long> num;  // num[i]: number of rewards recorded for arm i
  vector<double> v, e;    // v[i]: reward sum of arm i; e[i]: sampling weight
  double t;               // temperature used by f()
  // Exponent applied to the average reward when t is recomputed, and the
  // amount it shrinks by on every recomputation (floored at 0.7).
  // These are per-instance members so that independent bandits do not
  // advance each other's schedule.
  double u = 1.0;
  double du = 0.01;

  // Chooses the arm to play next and remembers it in `last`.
  // Must not be called again before reward() has been called.
  int play() {
    assert(last == -1);
    iter++;
    // Warm-up: plain round-robin over the arms.
    if (iter <= thres) return last = iter % N;
    // Roulette-wheel selection over the weights e[].
    double s = accumulate(begin(e), end(e), 0.0);
    double x = rnd() * s;
    for (int i = 0; i < N; i++) {
      if ((x -= e[i]) <= 0) return last = i;
    }
    // Reached only if rounding left a positive remainder (or a weight is not
    // finite); fall back to the last arm.
    return last = N - 1;
  }

  // Weighting function: maps a mean reward to a sampling weight.
  double f(double x) { return exp(x / t); }

  // Records reward `y` for the arm returned by the preceding play().
  void reward(double y) {
    assert(last != -1);
    v[last] += y;
    num[last] += 1;
    e[last] = f(v[last] / num[last]);
    last = -1;
    // Every `thres` plays the temperature t is re-derived and all weights are
    // refreshed with the new t.
    if (iter % thres == 0) {
      u = max(0.7, u - du);
      // Note: the reward sum of all arms is divided by `thres`, not by the
      // number of plays so far.
      double average = accumulate(begin(v), end(v), 0.0) / thres;
      // Only a strictly positive average yields a usable temperature.
      // pow(0, u) would give t == 0, and exp(x / 0) turns the weights into
      // NaN/inf, so fall back to t = 1 in that case as well.
      t = average > 0.0 ? pow(average, u) : 1.0;
      for (int i = 0; i < N; i++) e[i] = f(v[i] / num[i]);
    }
  }

  // Index of the arm with the largest weight (first one on ties).
  int best() { return max_element(begin(e), end(e)) - begin(e); }
};