Trie(字典树、前缀树、单词查找树)是一棵“非典型”的多叉树模型。

TrieNode 中并没有直接保存字符值的数据成员:字符值隐含在字母映射表 next 的下标中,父节点通过 next 即可找到它的所有子节点。

字典树应用

  • 检索字符串
  • AC 自动机(trie 是 AC 自动机的一部分)
  • 维护异或极值 —— 将数的二进制表示看作一个字符串,构建 01-trie
  • 维护异或和

基础代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Capacity of the node pool: one slot per possible trie node.
static constexpr int MX = 200000 + 10;

// trie[c][p] is the id of the child of node p reached via letter index c;
// 0 means "no such child" (node 0 is the root and is never a child).
int trie[26][MX] = {};
// cnt[p]: how many inserted words pass through node p.
int cnt[MX] = {};
// Id of the most recently allocated node (0 while only the root exists).
int tot = 0;
// isEnd[p]: true when some inserted word ends exactly at node p.
bool isEnd[MX] = {};

// Maps a character to its offset from 'a' ('a' -> 0, 'b' -> 1, ...).
int f(char ch) {
    const int offset = ch - 'a';
    return offset;
}

void insert(const string& s) {
int p = 0;
for (char ch : s) {
int o = f(ch);
if (!trie[o][p]) {
trie[o][p] = ++tot;
}
p = trie[o][p];
cnt[p]++;
}
isEnd[p] = true;
}

bool find(const string& s) {
int p = 0;
for (char ch : s) {
int o = f(ch);
if (!trie[o][p]) {
return false;
}
p = trie[o][p];
}
return isEnd[p];
}

// Returns true when the path spelled by s exists in the trie, i.e. s is a
// prefix of at least one inserted word (the empty string always matches).
bool startsWith(const string& s) {
    int node = 0;  // node 0 is the root
    for (const char ch : s) {
        const int next = trie[f(ch)][node];
        if (next == 0) {
            return false;  // 0 marks a missing child
        }
        node = next;
    }
    return true;
}

// 以 s(非空)为前缀的模式串的个数
int getCnt(const string& s) {
int p = 0;
for (char ch : s) {
int o = f(ch);
if (!trie[o][p]) {
return 0;
}
p = trie[o][p];
}
return cnt[p];
}

01-Trie 代码实现

常数较大,按需取舍

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Sentinel magnitude (10^18) returned by the xor queries on an empty trie.
static constexpr i64 inf = 1'000'000'000'000'000'000LL;

// One node of the binary (0/1) trie.
struct TrieNode {
    // Pool indices of the children for bit 0 and bit 1; -1 means "no child".
    array<int, 2> son = {-1, -1};
    // Number of stored values whose path passes through this node.
    int cnt {0};
    // Minimum value in this subtree.
    i64 mn {numeric_limits<i64>::max()};
};

class Trie {
vector<TrieNode> trie;
int N = 0;

public:
Trie(int n) : N(n) { trie.push_back(TrieNode()); }

// n: 预分配的节点数
void reserve(int n) { trie.reserve(n * (N + 1)); }

int newNode() {
trie.push_back(TrieNode());
return trie.size() - 1;
}

void insert(i64 x) {
int p = 0;
trie[p].cnt++;
trie[p].mn = min(trie[p].mn, x);
for (int i = N - 1; i >= 0; i--) {
int o = x >> i & 1;
if (trie[p].son[o] == -1) {
trie[p].son[o] = newNode();
}
p = trie[p].son[o];
trie[p].cnt++;
trie[p].mn = min(trie[p].mn, x);
}
}

// ! remove 中没有维护 mn
void remove(i64 x) {
int p = 0;
trie[p].cnt--;
for (int i = N - 1; i >= 0; i--) {
int o = x >> i & 1;
if (trie[p].son[o] == -1) {
trie[p].son[o] = newNode();
}
p = trie[p].son[o];
trie[p].cnt--;
}
}

// 也可以用哈希表
i64 maxXor(i64 x) const {
if (trie.size() == 1) {
return -inf;
}
i64 ans = 0;
int p = 0;
for (int i = N - 1; i >= 0; i--) {
int o = x >> i & 1;
if (trie[p].son[o ^ 1] != -1 && trie[trie[p].son[o ^ 1]].cnt > 0) {
ans |= 1ll << i;
o ^= 1;
}
p = trie[p].son[o];
}
return ans;
}

i64 minXor(i64 x) const {
if (trie.size() == 1) {
return inf;
}
i64 ans = 0;
int p = 0;
for (int i = N - 1; i >= 0; i--) {
int o = x >> i & 1;
if (trie[p].son[o] == -1 || trie[trie[p].son[o]].cnt <= 0) {
ans |= 1ll << i;
o ^= 1;
}
p = trie[p].son[o];
}
return ans;
}

// 返回 x 与 trie 上所有数的第 k 大异或值
// k 从 1 开始,超过总数则返回 0
i64 kth_maxXor(i64 x, int k) const {
i64 ans = 0;
int p = 0;
for (int i = N - 1; i >= 0; i--) {
int o = x >> i & 1;
int cnt = (trie[p].son[o ^ 1] == -1 ? 0 : trie[trie[p].son[o ^ 1]].cnt);
if (k <= cnt) {
ans |= 1ll << i;
o ^= 1;
} else {
k -= cnt;
}
p = trie[p].son[o];
if (p == -1) {
break;
}
}
return ans;
}

// 统计与 x 异或值 ≤ limit 的元素个数
// 也可以用哈希表
i64 countXorWithLimit(i64 x, i64 limit) const {
// 核心原理是,当 limit+1 的某一位是 1 的时候,若该位异或值取 0,则后面的位是可以取任意数字的
// 如果在 limit 上而不是 limit+1 上讨论,就要单独处理走到叶子的情况了(恰好等于 limit)
limit++;
i64 ans = 0;
int p = 0;
for (int i = N - 1; i >= 0 && p != -1; i--) {
int o = x >> i & 1;
if (limit >> i & 1) {
if (trie[p].son[o] != -1) {
ans += trie[trie[p].son[o]].cnt;
}
o ^= 1;
}
p = trie[p].son[o];
}
return ans;
}

// x 与 trie 上所有 ≤ limit 的数的最大异或值
// 不存在时返回 -1
i64 maxXorWithLimit(i64 x, i64 limit) const {
if (trie[0].mn > limit) {
return -1;
}
i64 ans = 0;
int p = 0;
for (int i = N - 1; i >= 0; --i) {
int o = (x >> i) & 1;
if (trie[p].son[o ^ 1] != -1 && trie[trie[p].son[o ^ 1]].cnt > 0 && trie[trie[p].son[o ^ 1]].mn <= limit) {
ans |= 1ll << i;
o ^= 1;
}
p = trie[p].son[o];
}
return ans;
}

// x 与 trie 上所有数的异或 <= limit 的最大异或值
// 不存在时返回 -1
i64 maxXorWithLimitXor(i64 x, i64 limit) const {
limit++;
i64 ans = 0;
int p = 0;
// 记录最后一次我们“仍能走 0 分支”,但 limit 那位是 1 的情况
int last_p = -1, last_i = -1;
i64 last_ans = 0;
for (int i = N - 1; i >= 0 && p != -1; i--) {
int o = x >> i & 1;
if (limit >> i & 1) {
// 分两种:如果走 son[o],当前位异或 0,依然 < limit 在这位的 1
if (trie[p].son[o] != -1) {
last_p = trie[p].son[o];
last_i = i - 1;
last_ans = ans;
}
// 如果走 son[o^1],当前位异或 1
if (trie[p].son[o ^ 1] != -1) {
ans |= 1ll << i;
}
// 实际走的还是 o^1
o ^= 1;
}
p = trie[p].son[o];
}
if (last_p == -1) {
return -1;
}
ans = last_ans;
p = last_p;
for (int i = last_i; i >= 0; i--) {
int o = x >> i & 1;
if (trie[p].son[o ^ 1] != -1 && trie[trie[p].son[o ^ 1]].cnt > 0) {
ans |= 1ll << i;
o ^= 1;
}
p = trie[p].son[o];
}
return ans;
}

// 完全图,边权为 a[u]^a[v],求 MST
// Boruvka 算法,分治连边
template <integral T>
auto xorMST(const vector<T>& a) {
T ans {};
auto dfs = [&](auto&& dfs, auto& a, int p) {
if (a.empty() || p < 0) {
return;
}
vector<T> b[2];
b[0].reserve(a.size());
b[1].reserve(a.size());
for (auto& v : a) {
b[v >> p & 1].push_back(v);
}
if (!b[0].empty() && !b[1].empty()) {
if (b[0].size() > b[1].size()) {
swap(b[0], b[1]);
}
Trie t(30); // todo
t.reserve(b[0].size());
for (auto& x : b[0]) {
t.insert(x);
}
T minXor = numeric_limits<T>::max();
for (auto& x : b[1]) {
minXor = min(minXor, t.minXor(x));
}
ans += minXor;
}
dfs(dfs, b[0], p - 1);
dfs(dfs, b[1], p - 1);
};
dfs(dfs, a, N - 1);
return ans;
}
};

用 hashmap 实现 Trie

AI 生成的,暂时还没有用过

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/**
 * @brief Trie whose nodes keep their children in a hash map keyed by char16_t.
 * Words are std::string values walked one char at a time. Supports insert,
 * delete, exact search, prefix test, and listing every stored word that
 * starts with a given prefix.
 */
class Trie {
private:
    struct Node {
        // child nodes, keyed by the next character
        std::unordered_map<char16_t, std::shared_ptr<Node>> children;
        // true when a stored word ends at this node
        bool word_end = false;
    };

    // root of the trie; represents the empty prefix
    std::shared_ptr<Node> root_node = std::make_shared<Node>();

    // Follows text from the root and returns the node reached, or nullptr
    // when some character on the way has no matching child.
    std::shared_ptr<Node> walk(const std::string& text) const {
        std::shared_ptr<Node> curr = root_node;
        for (char ch : text) {
            const auto it = curr->children.find(ch);
            if (it == curr->children.end() || !it->second) {
                return nullptr;
            }
            curr = it->second;
        }
        return curr;
    }

    // Appends to out every word stored at or below node; prefix holds the
    // characters leading to node and is restored before returning.
    static void collect(const Node& node, std::string& prefix, std::vector<std::string>& out) {
        if (node.word_end) {
            out.push_back(prefix);
        }
        for (const auto& [key, child] : node.children) {
            prefix.push_back(static_cast<char>(key));
            collect(*child, prefix, out);
            prefix.pop_back();
        }
    }

public:
    Trie() = default;

    /// Stores word in the trie. Inserting the empty string has no effect.
    void insert(const std::string& word) {
        std::shared_ptr<Node> curr = root_node;
        for (char ch : word) {
            std::shared_ptr<Node>& slot = curr->children[ch];
            if (!slot) {
                slot = std::make_shared<Node>();
            }
            std::shared_ptr<Node> next = slot;
            curr = std::move(next);
        }
        if (curr != root_node) {
            curr->word_end = true;
        }
    }

    /// @returns true when word was inserted as a complete word.
    bool search(const std::string& word) const {
        const std::shared_ptr<Node> node = walk(word);
        return node && node->word_end;
    }

    /// @returns true when at least one path in the trie starts with prefix.
    bool startwith(const std::string& prefix) const {
        return walk(prefix) != nullptr;
    }

    /**
     * @brief Removes word from the trie; does nothing when it is not stored.
     * Nodes that end up carrying neither a word nor children are pruned.
     */
    void delete_word(std::string word) {
        // The stack holds the whole path including the root, so pruning can
        // always reach a parent and never calls top() on an empty stack.
        std::stack<std::shared_ptr<Node>> nodes;
        nodes.push(root_node);
        std::shared_ptr<Node> curr = root_node;
        for (char ch : word) {
            const auto it = curr->children.find(ch);
            if (it == curr->children.end()) {
                return;
            }
            curr = it->second;
            nodes.push(curr);
        }
        if (!curr->word_end) {
            return;  // only a prefix (or the empty string): nothing to delete
        }
        curr->word_end = false;
        // Walk back up, dropping nodes that no longer serve any word.
        // The root (stack size 1) is never removed.
        while (nodes.size() > 1 && !nodes.top()->word_end && nodes.top()->children.empty()) {
            nodes.pop();
            nodes.top()->children.erase(word.back());
            word.pop_back();
        }
    }

    /**
     * @brief Collects every stored word at or below element.
     * @param results words gathered so far; new words are appended to it
     * @param element node reached after consuming prefix
     * @param prefix characters on the path from the root to element
     * @returns results extended with all words found below element
     */
    std::vector<std::string> get_all_words(std::vector<std::string> results, const std::shared_ptr<Node>& element, std::string prefix) const {
        if (element) {
            collect(*element, prefix, results);
        }
        return results;
    }

    /// @returns every stored word that starts with prefix (empty when none).
    std::vector<std::string> predict_words(const std::string& prefix) const {
        const std::shared_ptr<Node> start = walk(prefix);
        if (!start) {
            return {};
        }
        return get_all_words({}, start, prefix);
    }
};