// SPDX-License-Identifier: MIT
// Adapted from https://github.com/managarm/frigg
3 | |
4 | #pragma once |
5 | |
#include "mos/assert.hpp"

#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <mos/allocator.hpp>
#include <mos/type_utils.hpp>
#include <mos_stdlib.hpp>
#include <optional>
#include <tuple>
#include <utility>
15 | |
16 | namespace mos |
17 | { |
18 | template<typename Key, typename Value, typename TAllocator = mos::default_allocator> |
19 | class HashMap |
20 | { |
21 | public: |
22 | typedef std::tuple<const Key, Value> entry_type; |
23 | |
24 | private: |
25 | struct chain : mos::NamedType<"HashMap.Chain" > |
26 | { |
27 | entry_type entry; |
28 | chain *next; |
29 | explicit chain(const Key &new_key, const Value &new_value) : entry(new_key, new_value), next(nullptr) {}; |
30 | explicit chain(const Key &new_key, Value &&new_value) : entry(new_key, std::move(new_value)), next(nullptr) {}; |
31 | }; |
32 | |
33 | public: |
34 | class iterator |
35 | { |
36 | friend class HashMap; |
37 | |
38 | public: |
39 | iterator &operator++() |
40 | { |
41 | MOS_ASSERT(item); |
42 | item = item->next; |
43 | if (item) |
44 | return *this; |
45 | |
46 | MOS_ASSERT(bucket < map->_capacity); |
47 | while (true) |
48 | { |
49 | bucket++; |
50 | if (bucket == map->_capacity) |
51 | break; |
52 | item = map->_table[bucket]; |
53 | if (item) |
54 | break; |
55 | } |
56 | |
57 | return *this; |
58 | } |
59 | |
60 | bool operator==(const iterator &other) const |
61 | { |
62 | return (bucket == other.bucket) && (item == other.item); |
63 | } |
64 | |
65 | entry_type &operator*() |
66 | { |
67 | return item->entry; |
68 | } |
69 | entry_type *operator->() |
70 | { |
71 | return &item->entry; |
72 | } |
73 | |
74 | operator bool() |
75 | { |
76 | return item != nullptr; |
77 | } |
78 | |
79 | private: |
80 | iterator(HashMap *map, size_t bucket, chain *item) : map(map), item(item), bucket(bucket) {}; |
81 | HashMap *map; |
82 | chain *item; |
83 | size_t bucket; |
84 | }; |
85 | |
86 | class const_iterator |
87 | { |
88 | friend class HashMap; |
89 | |
90 | public: |
91 | const_iterator &operator++() |
92 | { |
93 | MOS_ASSERT(item); |
94 | item = item->next; |
95 | if (item) |
96 | return *this; |
97 | |
98 | MOS_ASSERT(bucket < map->_capacity); |
99 | while (true) |
100 | { |
101 | bucket++; |
102 | if (bucket == map->_capacity) |
103 | break; |
104 | item = map->_table[bucket]; |
105 | if (item) |
106 | break; |
107 | } |
108 | |
109 | return *this; |
110 | } |
111 | |
112 | bool operator==(const const_iterator &other) const |
113 | { |
114 | return (bucket == other.bucket) && (item == other.item); |
115 | } |
116 | |
117 | const entry_type &operator*() const |
118 | { |
119 | return item->entry; |
120 | } |
121 | |
122 | const entry_type *operator->() const |
123 | { |
124 | return &item->entry; |
125 | } |
126 | |
127 | operator bool() const |
128 | { |
129 | return item != nullptr; |
130 | } |
131 | |
132 | private: |
133 | const_iterator(const HashMap *map, size_t bucket, const chain *item) : map(map), item(item), bucket(bucket) {}; |
134 | const HashMap *map; |
135 | const chain *item; |
136 | size_t bucket; |
137 | }; |
138 | |
139 | constexpr HashMap() : _table(nullptr), _capacity(0), _size(0) {}; |
140 | HashMap(std::initializer_list<entry_type> init) : _table(nullptr), _capacity(0), _size(0) |
141 | { |
142 | /* TODO: we know the size so we don't have to keep rehashing?? */ |
143 | for (auto &[key, value] : init) |
144 | insert(key, value); |
145 | } |
146 | |
147 | ~HashMap() |
148 | { |
149 | for (size_t i = 0; i < _capacity; i++) |
150 | { |
151 | chain *item = _table[i]; |
152 | while (item != nullptr) |
153 | { |
154 | chain *next = item->next; |
155 | delete item; |
156 | item = next; |
157 | } |
158 | } |
159 | |
160 | TAllocator::free(_table, sizeof(chain *) * _capacity); |
161 | } |
162 | |
163 | HashMap(const HashMap &) = delete; |
164 | |
165 | void insert(const Key &key, const Value &value); |
166 | void insert(const Key &key, Value &&value); |
167 | Value &operator[](const Key &key); |
168 | |
169 | bool empty() |
170 | { |
171 | return !_size; |
172 | } |
173 | |
174 | iterator end() |
175 | { |
176 | return iterator(this, _capacity, nullptr); |
177 | } |
178 | |
179 | iterator find(const Key &key) |
180 | { |
181 | if (!_size) |
182 | return end(); |
183 | |
184 | const auto bucket = (std::hash<Key>{}(key) % _capacity); |
185 | for (chain *item = _table[bucket]; item != nullptr; item = item->next) |
186 | { |
187 | if (std::get<0>(item->entry) == key) |
188 | return iterator(this, bucket, item); |
189 | } |
190 | |
191 | return end(); |
192 | } |
193 | |
194 | iterator begin() |
195 | { |
196 | if (!_size) |
197 | return iterator(this, _capacity, nullptr); |
198 | |
199 | for (size_t bucket = 0; bucket < _capacity; bucket++) |
200 | { |
201 | if (_table[bucket]) |
202 | return iterator(this, bucket, _table[bucket]); |
203 | } |
204 | |
205 | MOS_ASSERT(!"hash_map corrupted" ); |
206 | MOS_UNREACHABLE(); |
207 | } |
208 | |
209 | const_iterator end() const |
210 | { |
211 | return const_iterator(this, _capacity, nullptr); |
212 | } |
213 | |
214 | const_iterator find(const Key &key) const |
215 | { |
216 | if (!_size) |
217 | return end(); |
218 | |
219 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
220 | for (const chain *item = _table[bucket]; item != nullptr; item = item->next) |
221 | { |
222 | if (std::get<0>(item->entry) == key) |
223 | return const_iterator(this, bucket, item); |
224 | } |
225 | |
226 | return end(); |
227 | } |
228 | |
229 | std::optional<Value> get(const Key &key); |
230 | std::optional<Value> remove(const Key &key); |
231 | |
232 | size_t size() const |
233 | { |
234 | return _size; |
235 | } |
236 | |
237 | private: |
238 | void rehash(); |
239 | |
240 | chain **_table; |
241 | size_t _capacity; |
242 | size_t _size; |
243 | }; |
244 | |
245 | template<typename Key, typename Value, typename TAllocator> |
246 | void HashMap<Key, Value, TAllocator>::insert(const Key &key, const Value &value) |
247 | { |
248 | if (_size >= _capacity) |
249 | rehash(); |
250 | |
251 | MOS_ASSERT(_capacity > 0); |
252 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
253 | |
254 | auto item = mos::create<chain>(key, value); |
255 | item->next = _table[bucket]; |
256 | _table[bucket] = item; |
257 | _size++; |
258 | } |
259 | |
260 | template<typename Key, typename Value, typename TAllocator> |
261 | void HashMap<Key, Value, TAllocator>::insert(const Key &key, Value &&value) |
262 | { |
263 | if (_size >= _capacity) |
264 | rehash(); |
265 | |
266 | MOS_ASSERT(_capacity > 0); |
267 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
268 | |
269 | auto item = mos::create<chain>(key, std::move(value)); |
270 | item->next = _table[bucket]; |
271 | _table[bucket] = item; |
272 | _size++; |
273 | } |
274 | |
275 | template<typename Key, typename Value, typename TAllocator> |
276 | Value &HashMap<Key, Value, TAllocator>::operator[](const Key &key) |
277 | { |
278 | /* empty map case */ |
279 | if (_size == 0) |
280 | { |
281 | rehash(); |
282 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
283 | auto item = mos::create<chain>(key, Value{}); |
284 | item->next = _table[bucket]; |
285 | _table[bucket] = item; |
286 | _size++; |
287 | } |
288 | |
289 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
290 | for (chain *item = _table[bucket]; item != nullptr; item = item->next) |
291 | { |
292 | if (std::get<0>(item->entry) == key) |
293 | return std::get<1>(item->entry); |
294 | } |
295 | |
296 | if (_size >= _capacity) |
297 | rehash(); |
298 | |
299 | auto item = mos::create<chain>(key, Value{}); |
300 | item->next = _table[bucket]; |
301 | _table[bucket] = item; |
302 | _size++; |
303 | return std::get<1>(item->entry); |
304 | } |
305 | |
306 | template<typename Key, typename Value, typename TAllocator> |
307 | std::optional<Value> HashMap<Key, Value, TAllocator>::get(const Key &key) |
308 | { |
309 | if (_size == 0) |
310 | return std::nullopt; |
311 | |
312 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
313 | |
314 | for (chain *item = _table[bucket]; item != nullptr; item = item->next) |
315 | { |
316 | if (std::get<0>(item->entry) == key) |
317 | return std::get<1>(item->entry); |
318 | } |
319 | |
320 | return std::nullopt; |
321 | } |
322 | |
323 | template<typename Key, typename Value, typename TAllocator> |
324 | std::optional<Value> HashMap<Key, Value, TAllocator>::remove(const Key &key) |
325 | { |
326 | if (_size == 0) |
327 | return std::nullopt; |
328 | |
329 | const auto bucket = (std::hash<Key>{}(key)) % _capacity; |
330 | |
331 | chain *previous = nullptr; |
332 | for (chain *item = _table[bucket]; item != nullptr; item = item->next) |
333 | { |
334 | if (std::get<0>(item->entry) == key) |
335 | { |
336 | Value value = std::move(std::get<1>(item->entry)); |
337 | |
338 | if (previous == nullptr) |
339 | _table[bucket] = item->next; |
340 | else |
341 | previous->next = item->next; |
342 | delete item; |
343 | _size--; |
344 | return value; |
345 | } |
346 | |
347 | previous = item; |
348 | } |
349 | |
350 | return std::nullopt; |
351 | } |
352 | |
353 | template<typename Key, typename Value, typename TAllocator> |
354 | void HashMap<Key, Value, TAllocator>::rehash() |
355 | { |
356 | const size_t new_capacity = std::max(a: 2 * _size, b: 10lu); |
357 | |
358 | const auto new_table = kcalloc<chain *>(new_capacity); |
359 | for (size_t i = 0; i < new_capacity; i++) |
360 | new_table[i] = nullptr; |
361 | |
362 | for (size_t i = 0; i < _capacity; i++) |
363 | { |
364 | auto item = _table[i]; |
365 | while (item != nullptr) |
366 | { |
367 | const auto &key = std::get<0>(item->entry); |
368 | const auto bucket = (std::hash<Key>{}(key)) % new_capacity; |
369 | |
370 | const auto next = item->next; |
371 | item->next = new_table[bucket]; |
372 | new_table[bucket] = item; |
373 | item = next; |
374 | } |
375 | } |
376 | |
377 | TAllocator::free(_table, sizeof(chain *) * _capacity); |
378 | _table = new_table; |
379 | _capacity = new_capacity; |
380 | } |
381 | } // namespace mos |
382 | |