string/aho-corasick.hpp
Depends on
Verified with
Code
#pragma once
#include "trie.hpp"
// Aho-Corasick automaton over the alphabet [margin, margin + X).
// It is built on Trie<X + 1, margin>: after build(), slot X of every node's
// nxt array holds that node's failure link.
// Usage: add() every pattern, call build() once, then match() / count().
// heavy == false: match() returns the total number of pattern occurrences.
// heavy == true : match() returns, per pattern id, its number of occurrences;
//                 build() stores in st[v].idxs the sorted union of the pattern
//                 ids on v's failure chain (this can be large).
template <size_t X = 26, char margin = 'a', bool heavy = false>
struct AhoCorasick : Trie<X + 1, margin> {
  using TRIE = Trie<X + 1, margin>;
  using TRIE::next;
  using TRIE::st;
  using TRIE::TRIE;
  // cnt[v]: number of added patterns ending at v or anywhere on its failure
  // chain, i.e. patterns that are suffixes of the string spelled by v.
  vector<int> cnt;

  // Converts the trie into a total automaton:
  //  - next(v, X) becomes the failure link of v,
  //  - every missing transition next(v, c) is redirected to next(fail(v), c),
  //  - cnt[v] (and, if heavy, st[v].idxs) accumulate along the failure chain.
  void build() {
    int n = st.size();
    cnt.resize(n);
    for (int i = 0; i < n; i++) {
      if (heavy) sort(st[i].idxs.begin(), st[i].idxs.end());
      cnt[i] = st[i].idxs.size();
    }
    // Replaces st[dst].idxs with the sorted union of itself and st[src].idxs.
    auto merge_idxs = [&](int dst, int src) {
      auto &a = st[dst].idxs;
      const auto &b = st[src].idxs;
      vector<int> merged;
      merged.reserve(a.size() + b.size());
      set_union(a.begin(), a.end(), b.begin(), b.end(), back_inserter(merged));
      a.swap(merged);
    };
    queue<int> que;
    // Depth-1 nodes fail to the root; missing root edges loop back to the root.
    for (int i = 0; i < (int)X; i++) {
      int child = next(0, i);
      if (child < 0) {
        next(0, i) = 0;
        continue;
      }
      next(child, X) = 0;
      // Patterns ending at the root (the empty pattern) are suffixes of every
      // node; merge them here too so heavy idxs agree with cnt.
      if (heavy) merge_idxs(child, 0);
      que.emplace(child);
    }
    // BFS order: a node's failure target is strictly shallower, so its cnt,
    // idxs and transitions are final before the node itself is processed.
    while (!que.empty()) {
      int v = que.front();
      que.pop();
      int fail = next(v, X);
      cnt[v] += cnt[fail];
      for (int i = 0; i < (int)X; i++) {
        int &nx = next(v, i);
        if (nx < 0) {
          nx = next(fail, i);
          continue;
        }
        next(nx, X) = next(fail, i);
        if (heavy) merge_idxs(nx, next(fail, i));
        que.emplace(nx);
      }
    }
  }

  // Feeds s through the automaton; build() must have been called after the
  // last add().
  //  heavy == true : returns a map pattern id -> number of matches in s.
  //  heavy == false: returns the total number of matches of all patterns.
  conditional_t<heavy, unordered_map<int, long long>, long long> match(
      const string &s) {
    assert(cnt.size() == st.size() && "call build() after the last add()");
    // Visit count per state, so each state's pattern list is expanded once.
    unordered_map<int, int> pos_cnt;
    int pos = 0;
    for (char c : s) {
      int k = c - margin;
      assert(0 <= k && k < (int)X);  // character outside the alphabet
      pos = next(pos, k);
      pos_cnt[pos]++;
    }
    conditional_t<heavy, unordered_map<int, long long>, long long> res{};
    for (auto &[key, val] : pos_cnt) {
      if constexpr (heavy) {
        for (auto &x : st[key].idxs) res[x] += val;
      } else {
        res += 1LL * cnt[key] * val;
      }
    }
    return res;
  }

  // Number of patterns that are suffixes of the string spelled by node pos
  // (valid after build()).
  int count(int pos) { return cnt[pos]; }
};
#line 2 "string/aho-corasick.hpp"
#line 2 "string/trie.hpp"
// Trie over the alphabet [margin, margin + X). Node 0 is the root; children
// are stored as indices into st, with -1 meaning "no child".
template <size_t X = 26, char margin = 'a'>
struct Trie {
  struct Node {
    array<int, X> nxt;  // nxt[k]: child for character margin + k, or -1
    vector<int> idxs;   // ids of every pattern added that ends at this node
    int idx;            // id of the most recently added pattern ending here, or -1
    char key;           // character passed when the node was created
    explicit Node(char c) : idx(-1), key(c) { fill(nxt.begin(), nxt.end(), -1); }
  };
  vector<Node> st;
  Trie(char c = '$') { st.emplace_back(c); }

  // Maps a character to its edge slot; asserts it lies in the alphabet
  // (an out-of-range slot would index nxt out of bounds).
  static int to_index(char c) {
    int k = c - margin;
    assert(0 <= k && k < (int)X);
    return k;
  }

  inline int &next(int i, int j) { return st[i].nxt[j]; }

  // Inserts s, creating nodes as needed, and records id x at its end node.
  void add(const string &s, int x) {
    int pos = 0;
    for (char c : s) {
      int k = to_index(c);
      if (next(pos, k) < 0) {
        // Assign before emplace_back; no reference into st is held across
        // the reallocation.
        next(pos, k) = st.size();
        st.emplace_back(c);
      }
      pos = next(pos, k);
    }
    st[pos].idx = x;
    st[pos].idxs.emplace_back(x);
  }

  // Returns the node spelling s, or -1 if s is not a path in the trie.
  int find(const string &s) {
    int pos = 0;
    for (char c : s) {
      pos = next(pos, to_index(c));
      if (pos < 0) return -1;
    }
    return pos;
  }

  // One transition from pos by c; -1 stays -1.
  int move(int pos, char c) {
    assert(pos < (int)st.size());
    return pos < 0 ? -1 : next(pos, to_index(c));
  }

  int size() const { return st.size(); }
  int idx(int pos) { return pos < 0 ? -1 : st[pos].idx; }
  vector<int> idxs(int pos) { return pos < 0 ? vector<int>() : st[pos].idxs; }
};
#line 4 "string/aho-corasick.hpp"
// Aho-Corasick automaton over the alphabet [margin, margin + X).
// It is built on Trie<X + 1, margin>: after build(), slot X of every node's
// nxt array holds that node's failure link.
// Usage: add() every pattern, call build() once, then match() / count().
// heavy == false: match() returns the total number of pattern occurrences.
// heavy == true : match() returns, per pattern id, its number of occurrences;
//                 build() stores in st[v].idxs the sorted union of the pattern
//                 ids on v's failure chain (this can be large).
template <size_t X = 26, char margin = 'a', bool heavy = false>
struct AhoCorasick : Trie<X + 1, margin> {
  using TRIE = Trie<X + 1, margin>;
  using TRIE::next;
  using TRIE::st;
  using TRIE::TRIE;
  // cnt[v]: number of added patterns ending at v or anywhere on its failure
  // chain, i.e. patterns that are suffixes of the string spelled by v.
  vector<int> cnt;

  // Converts the trie into a total automaton:
  //  - next(v, X) becomes the failure link of v,
  //  - every missing transition next(v, c) is redirected to next(fail(v), c),
  //  - cnt[v] (and, if heavy, st[v].idxs) accumulate along the failure chain.
  void build() {
    int n = st.size();
    cnt.resize(n);
    for (int i = 0; i < n; i++) {
      if (heavy) sort(st[i].idxs.begin(), st[i].idxs.end());
      cnt[i] = st[i].idxs.size();
    }
    // Replaces st[dst].idxs with the sorted union of itself and st[src].idxs.
    auto merge_idxs = [&](int dst, int src) {
      auto &a = st[dst].idxs;
      const auto &b = st[src].idxs;
      vector<int> merged;
      merged.reserve(a.size() + b.size());
      set_union(a.begin(), a.end(), b.begin(), b.end(), back_inserter(merged));
      a.swap(merged);
    };
    queue<int> que;
    // Depth-1 nodes fail to the root; missing root edges loop back to the root.
    for (int i = 0; i < (int)X; i++) {
      int child = next(0, i);
      if (child < 0) {
        next(0, i) = 0;
        continue;
      }
      next(child, X) = 0;
      // Patterns ending at the root (the empty pattern) are suffixes of every
      // node; merge them here too so heavy idxs agree with cnt.
      if (heavy) merge_idxs(child, 0);
      que.emplace(child);
    }
    // BFS order: a node's failure target is strictly shallower, so its cnt,
    // idxs and transitions are final before the node itself is processed.
    while (!que.empty()) {
      int v = que.front();
      que.pop();
      int fail = next(v, X);
      cnt[v] += cnt[fail];
      for (int i = 0; i < (int)X; i++) {
        int &nx = next(v, i);
        if (nx < 0) {
          nx = next(fail, i);
          continue;
        }
        next(nx, X) = next(fail, i);
        if (heavy) merge_idxs(nx, next(fail, i));
        que.emplace(nx);
      }
    }
  }

  // Feeds s through the automaton; build() must have been called after the
  // last add().
  //  heavy == true : returns a map pattern id -> number of matches in s.
  //  heavy == false: returns the total number of matches of all patterns.
  conditional_t<heavy, unordered_map<int, long long>, long long> match(
      const string &s) {
    assert(cnt.size() == st.size() && "call build() after the last add()");
    // Visit count per state, so each state's pattern list is expanded once.
    unordered_map<int, int> pos_cnt;
    int pos = 0;
    for (char c : s) {
      int k = c - margin;
      assert(0 <= k && k < (int)X);  // character outside the alphabet
      pos = next(pos, k);
      pos_cnt[pos]++;
    }
    conditional_t<heavy, unordered_map<int, long long>, long long> res{};
    for (auto &[key, val] : pos_cnt) {
      if constexpr (heavy) {
        for (auto &x : st[key].idxs) res[x] += val;
      } else {
        res += 1LL * cnt[key] * val;
      }
    }
    return res;
  }

  // Number of patterns that are suffixes of the string spelled by node pos
  // (valid after build()).
  int count(int pos) { return cnt[pos]; }
};
Back to top page