1// SPDX-License-Identifier: MIT
2// Adapted from https://github.com/managarm/frigg
3
4#pragma once
5
6#include "mos/assert.hpp"
7
8#include <algorithm>
9#include <cstddef>
10#include <functional>
11#include <mos/allocator.hpp>
12#include <mos/type_utils.hpp>
13#include <mos_stdlib.hpp>
14#include <optional>
15
16namespace mos
17{
18 template<typename Key, typename Value>
19 class HashMap
20 {
21 public:
22 typedef std::tuple<const Key, Value> entry_type;
23
24 private:
25 struct chain : mos::NamedType<"HashMap.Chain">
26 {
27 entry_type entry;
28 chain *next;
29 explicit chain(const Key &new_key, const Value &new_value) : entry(new_key, new_value), next(nullptr) {};
30 explicit chain(const Key &new_key, Value &&new_value) : entry(new_key, std::move(new_value)), next(nullptr) {};
31 };
32 using ChainAllocator = mos::default_allocator<chain *>;
33
34 public:
35 class iterator
36 {
37 friend class HashMap;
38
39 public:
40 iterator &operator++()
41 {
42 MOS_ASSERT(item);
43 item = item->next;
44 if (item)
45 return *this;
46
47 MOS_ASSERT(bucket < map->_capacity);
48 while (true)
49 {
50 bucket++;
51 if (bucket == map->_capacity)
52 break;
53 item = map->_table[bucket];
54 if (item)
55 break;
56 }
57
58 return *this;
59 }
60
61 bool operator==(const iterator &other) const
62 {
63 return (bucket == other.bucket) && (item == other.item);
64 }
65
66 entry_type &operator*()
67 {
68 return item->entry;
69 }
70 entry_type *operator->()
71 {
72 return &item->entry;
73 }
74
75 operator bool()
76 {
77 return item != nullptr;
78 }
79
80 private:
81 iterator(HashMap *map, size_t bucket, chain *item) : map(map), item(item), bucket(bucket) {};
82 HashMap *map;
83 chain *item;
84 size_t bucket;
85 };
86
87 class const_iterator
88 {
89 friend class HashMap;
90
91 public:
92 const_iterator &operator++()
93 {
94 MOS_ASSERT(item);
95 item = item->next;
96 if (item)
97 return *this;
98
99 MOS_ASSERT(bucket < map->_capacity);
100 while (true)
101 {
102 bucket++;
103 if (bucket == map->_capacity)
104 break;
105 item = map->_table[bucket];
106 if (item)
107 break;
108 }
109
110 return *this;
111 }
112
113 bool operator==(const const_iterator &other) const
114 {
115 return (bucket == other.bucket) && (item == other.item);
116 }
117
118 const entry_type &operator*() const
119 {
120 return item->entry;
121 }
122
123 const entry_type *operator->() const
124 {
125 return &item->entry;
126 }
127
128 operator bool() const
129 {
130 return item != nullptr;
131 }
132
133 private:
134 const_iterator(const HashMap *map, size_t bucket, const chain *item) : map(map), item(item), bucket(bucket) {};
135 const HashMap *map;
136 const chain *item;
137 size_t bucket;
138 };
139
140 constexpr HashMap() : _table(nullptr), _capacity(0), _size(0) {};
141 HashMap(std::initializer_list<entry_type> init) : _table(nullptr), _capacity(0), _size(0)
142 {
143 /* TODO: we know the size so we don't have to keep rehashing?? */
144 for (auto &[key, value] : init)
145 insert(key, value);
146 }
147
148 ~HashMap()
149 {
150 for (size_t i = 0; i < _capacity; i++)
151 {
152 chain *item = _table[i];
153 while (item != nullptr)
154 {
155 chain *next = item->next;
156 delete item;
157 item = next;
158 }
159 }
160
161 ChainAllocator::free(_table, sizeof(chain *) * _capacity);
162 }
163
164 HashMap(const HashMap &) = delete;
165
166 void insert(const Key &key, const Value &value);
167 void insert(const Key &key, Value &&value);
168 Value &operator[](const Key &key);
169
170 bool empty()
171 {
172 return !_size;
173 }
174
175 iterator end()
176 {
177 return iterator(this, _capacity, nullptr);
178 }
179
180 iterator find(const Key &key)
181 {
182 if (!_size)
183 return end();
184
185 const auto bucket = (std::hash<Key>{}(key) % _capacity);
186 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
187 {
188 if (std::get<0>(item->entry) == key)
189 return iterator(this, bucket, item);
190 }
191
192 return end();
193 }
194
195 iterator begin()
196 {
197 if (!_size)
198 return iterator(this, _capacity, nullptr);
199
200 for (size_t bucket = 0; bucket < _capacity; bucket++)
201 {
202 if (_table[bucket])
203 return iterator(this, bucket, _table[bucket]);
204 }
205
206 MOS_ASSERT(!"hash_map corrupted");
207 MOS_UNREACHABLE();
208 }
209
210 const_iterator end() const
211 {
212 return const_iterator(this, _capacity, nullptr);
213 }
214
215 const_iterator find(const Key &key) const
216 {
217 if (!_size)
218 return end();
219
220 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
221 for (const chain *item = _table[bucket]; item != nullptr; item = item->next)
222 {
223 if (std::get<0>(item->entry) == key)
224 return const_iterator(this, bucket, item);
225 }
226
227 return end();
228 }
229
230 std::optional<Value> get(const Key &key);
231 std::optional<Value> remove(const Key &key);
232
233 size_t size() const
234 {
235 return _size;
236 }
237
238 private:
239 void rehash();
240
241 chain **_table;
242 size_t _capacity;
243 size_t _size;
244 };
245
246 template<typename Key, typename Value>
247 void HashMap<Key, Value>::insert(const Key &key, const Value &value)
248 {
249 if (_size >= _capacity)
250 rehash();
251
252 MOS_ASSERT(_capacity > 0);
253 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
254
255 auto item = mos::create<chain>(key, value);
256 item->next = _table[bucket];
257 _table[bucket] = item;
258 _size++;
259 }
260
261 template<typename Key, typename Value>
262 void HashMap<Key, Value>::insert(const Key &key, Value &&value)
263 {
264 if (_size >= _capacity)
265 rehash();
266
267 MOS_ASSERT(_capacity > 0);
268 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
269
270 auto item = mos::create<chain>(key, std::move(value));
271 item->next = _table[bucket];
272 _table[bucket] = item;
273 _size++;
274 }
275
276 template<typename Key, typename Value>
277 Value &HashMap<Key, Value>::operator[](const Key &key)
278 {
279 /* empty map case */
280 if (_size == 0)
281 {
282 rehash();
283 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
284 auto item = mos::create<chain>(key, Value{});
285 item->next = _table[bucket];
286 _table[bucket] = item;
287 _size++;
288 }
289
290 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
291 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
292 {
293 if (std::get<0>(item->entry) == key)
294 return std::get<1>(item->entry);
295 }
296
297 if (_size >= _capacity)
298 rehash();
299
300 auto item = mos::create<chain>(key, Value{});
301 item->next = _table[bucket];
302 _table[bucket] = item;
303 _size++;
304 return std::get<1>(item->entry);
305 }
306
307 template<typename Key, typename Value>
308 std::optional<Value> HashMap<Key, Value>::get(const Key &key)
309 {
310 if (_size == 0)
311 return std::nullopt;
312
313 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
314
315 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
316 {
317 if (std::get<0>(item->entry) == key)
318 return std::get<1>(item->entry);
319 }
320
321 return std::nullopt;
322 }
323
324 template<typename Key, typename Value>
325 std::optional<Value> HashMap<Key, Value>::remove(const Key &key)
326 {
327 if (_size == 0)
328 return std::nullopt;
329
330 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
331
332 chain *previous = nullptr;
333 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
334 {
335 if (std::get<0>(item->entry) == key)
336 {
337 Value value = std::move(std::get<1>(item->entry));
338
339 if (previous == nullptr)
340 _table[bucket] = item->next;
341 else
342 previous->next = item->next;
343 delete item;
344 _size--;
345 return value;
346 }
347
348 previous = item;
349 }
350
351 return std::nullopt;
352 }
353
354 template<typename Key, typename Value>
355 void HashMap<Key, Value>::rehash()
356 {
357 const size_t new_capacity = std::max(a: 2 * _size, b: 10lu);
358
359 const auto new_table = kcalloc<chain *>(new_capacity);
360 for (size_t i = 0; i < new_capacity; i++)
361 new_table[i] = nullptr;
362
363 for (size_t i = 0; i < _capacity; i++)
364 {
365 auto item = _table[i];
366 while (item != nullptr)
367 {
368 const auto &key = std::get<0>(item->entry);
369 const auto bucket = (std::hash<Key>{}(key)) % new_capacity;
370
371 const auto next = item->next;
372 item->next = new_table[bucket];
373 new_table[bucket] = item;
374 item = next;
375 }
376 }
377
378 ChainAllocator::free(_table, sizeof(chain *) * _capacity);
379 _table = new_table;
380 _capacity = new_capacity;
381 }
382} // namespace mos
383