1// SPDX-License-Identifier: MIT
2// Adapted from https://github.com/managarm/frigg
3
4#pragma once
5
#include "mos/assert.hpp"

#include <algorithm>
#include <cstddef>
#include <functional>
#include <mos/allocator.hpp>
#include <mos/type_utils.hpp>
#include <mos_stdlib.hpp>
#include <optional>
#include <utility>
15
16namespace mos
17{
18 template<typename Key, typename Value, typename TAllocator = mos::default_allocator>
19 class HashMap
20 {
21 public:
22 typedef std::tuple<const Key, Value> entry_type;
23
24 private:
25 struct chain : mos::NamedType<"HashMap.Chain">
26 {
27 entry_type entry;
28 chain *next;
29 explicit chain(const Key &new_key, const Value &new_value) : entry(new_key, new_value), next(nullptr) {};
30 explicit chain(const Key &new_key, Value &&new_value) : entry(new_key, std::move(new_value)), next(nullptr) {};
31 };
32
33 public:
34 class iterator
35 {
36 friend class HashMap;
37
38 public:
39 iterator &operator++()
40 {
41 MOS_ASSERT(item);
42 item = item->next;
43 if (item)
44 return *this;
45
46 MOS_ASSERT(bucket < map->_capacity);
47 while (true)
48 {
49 bucket++;
50 if (bucket == map->_capacity)
51 break;
52 item = map->_table[bucket];
53 if (item)
54 break;
55 }
56
57 return *this;
58 }
59
60 bool operator==(const iterator &other) const
61 {
62 return (bucket == other.bucket) && (item == other.item);
63 }
64
65 entry_type &operator*()
66 {
67 return item->entry;
68 }
69 entry_type *operator->()
70 {
71 return &item->entry;
72 }
73
74 operator bool()
75 {
76 return item != nullptr;
77 }
78
79 private:
80 iterator(HashMap *map, size_t bucket, chain *item) : map(map), item(item), bucket(bucket) {};
81 HashMap *map;
82 chain *item;
83 size_t bucket;
84 };
85
86 class const_iterator
87 {
88 friend class HashMap;
89
90 public:
91 const_iterator &operator++()
92 {
93 MOS_ASSERT(item);
94 item = item->next;
95 if (item)
96 return *this;
97
98 MOS_ASSERT(bucket < map->_capacity);
99 while (true)
100 {
101 bucket++;
102 if (bucket == map->_capacity)
103 break;
104 item = map->_table[bucket];
105 if (item)
106 break;
107 }
108
109 return *this;
110 }
111
112 bool operator==(const const_iterator &other) const
113 {
114 return (bucket == other.bucket) && (item == other.item);
115 }
116
117 const entry_type &operator*() const
118 {
119 return item->entry;
120 }
121
122 const entry_type *operator->() const
123 {
124 return &item->entry;
125 }
126
127 operator bool() const
128 {
129 return item != nullptr;
130 }
131
132 private:
133 const_iterator(const HashMap *map, size_t bucket, const chain *item) : map(map), item(item), bucket(bucket) {};
134 const HashMap *map;
135 const chain *item;
136 size_t bucket;
137 };
138
139 constexpr HashMap() : _table(nullptr), _capacity(0), _size(0) {};
140 HashMap(std::initializer_list<entry_type> init) : _table(nullptr), _capacity(0), _size(0)
141 {
142 /* TODO: we know the size so we don't have to keep rehashing?? */
143 for (auto &[key, value] : init)
144 insert(key, value);
145 }
146
147 ~HashMap()
148 {
149 for (size_t i = 0; i < _capacity; i++)
150 {
151 chain *item = _table[i];
152 while (item != nullptr)
153 {
154 chain *next = item->next;
155 delete item;
156 item = next;
157 }
158 }
159
160 TAllocator::free(_table, sizeof(chain *) * _capacity);
161 }
162
163 HashMap(const HashMap &) = delete;
164
165 void insert(const Key &key, const Value &value);
166 void insert(const Key &key, Value &&value);
167 Value &operator[](const Key &key);
168
169 bool empty()
170 {
171 return !_size;
172 }
173
174 iterator end()
175 {
176 return iterator(this, _capacity, nullptr);
177 }
178
179 iterator find(const Key &key)
180 {
181 if (!_size)
182 return end();
183
184 const auto bucket = (std::hash<Key>{}(key) % _capacity);
185 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
186 {
187 if (std::get<0>(item->entry) == key)
188 return iterator(this, bucket, item);
189 }
190
191 return end();
192 }
193
194 iterator begin()
195 {
196 if (!_size)
197 return iterator(this, _capacity, nullptr);
198
199 for (size_t bucket = 0; bucket < _capacity; bucket++)
200 {
201 if (_table[bucket])
202 return iterator(this, bucket, _table[bucket]);
203 }
204
205 MOS_ASSERT(!"hash_map corrupted");
206 MOS_UNREACHABLE();
207 }
208
209 const_iterator end() const
210 {
211 return const_iterator(this, _capacity, nullptr);
212 }
213
214 const_iterator find(const Key &key) const
215 {
216 if (!_size)
217 return end();
218
219 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
220 for (const chain *item = _table[bucket]; item != nullptr; item = item->next)
221 {
222 if (std::get<0>(item->entry) == key)
223 return const_iterator(this, bucket, item);
224 }
225
226 return end();
227 }
228
229 std::optional<Value> get(const Key &key);
230 std::optional<Value> remove(const Key &key);
231
232 size_t size() const
233 {
234 return _size;
235 }
236
237 private:
238 void rehash();
239
240 chain **_table;
241 size_t _capacity;
242 size_t _size;
243 };
244
245 template<typename Key, typename Value, typename TAllocator>
246 void HashMap<Key, Value, TAllocator>::insert(const Key &key, const Value &value)
247 {
248 if (_size >= _capacity)
249 rehash();
250
251 MOS_ASSERT(_capacity > 0);
252 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
253
254 auto item = mos::create<chain>(key, value);
255 item->next = _table[bucket];
256 _table[bucket] = item;
257 _size++;
258 }
259
260 template<typename Key, typename Value, typename TAllocator>
261 void HashMap<Key, Value, TAllocator>::insert(const Key &key, Value &&value)
262 {
263 if (_size >= _capacity)
264 rehash();
265
266 MOS_ASSERT(_capacity > 0);
267 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
268
269 auto item = mos::create<chain>(key, std::move(value));
270 item->next = _table[bucket];
271 _table[bucket] = item;
272 _size++;
273 }
274
275 template<typename Key, typename Value, typename TAllocator>
276 Value &HashMap<Key, Value, TAllocator>::operator[](const Key &key)
277 {
278 /* empty map case */
279 if (_size == 0)
280 {
281 rehash();
282 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
283 auto item = mos::create<chain>(key, Value{});
284 item->next = _table[bucket];
285 _table[bucket] = item;
286 _size++;
287 }
288
289 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
290 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
291 {
292 if (std::get<0>(item->entry) == key)
293 return std::get<1>(item->entry);
294 }
295
296 if (_size >= _capacity)
297 rehash();
298
299 auto item = mos::create<chain>(key, Value{});
300 item->next = _table[bucket];
301 _table[bucket] = item;
302 _size++;
303 return std::get<1>(item->entry);
304 }
305
306 template<typename Key, typename Value, typename TAllocator>
307 std::optional<Value> HashMap<Key, Value, TAllocator>::get(const Key &key)
308 {
309 if (_size == 0)
310 return std::nullopt;
311
312 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
313
314 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
315 {
316 if (std::get<0>(item->entry) == key)
317 return std::get<1>(item->entry);
318 }
319
320 return std::nullopt;
321 }
322
323 template<typename Key, typename Value, typename TAllocator>
324 std::optional<Value> HashMap<Key, Value, TAllocator>::remove(const Key &key)
325 {
326 if (_size == 0)
327 return std::nullopt;
328
329 const auto bucket = (std::hash<Key>{}(key)) % _capacity;
330
331 chain *previous = nullptr;
332 for (chain *item = _table[bucket]; item != nullptr; item = item->next)
333 {
334 if (std::get<0>(item->entry) == key)
335 {
336 Value value = std::move(std::get<1>(item->entry));
337
338 if (previous == nullptr)
339 _table[bucket] = item->next;
340 else
341 previous->next = item->next;
342 delete item;
343 _size--;
344 return value;
345 }
346
347 previous = item;
348 }
349
350 return std::nullopt;
351 }
352
353 template<typename Key, typename Value, typename TAllocator>
354 void HashMap<Key, Value, TAllocator>::rehash()
355 {
356 const size_t new_capacity = std::max(a: 2 * _size, b: 10lu);
357
358 const auto new_table = kcalloc<chain *>(new_capacity);
359 for (size_t i = 0; i < new_capacity; i++)
360 new_table[i] = nullptr;
361
362 for (size_t i = 0; i < _capacity; i++)
363 {
364 auto item = _table[i];
365 while (item != nullptr)
366 {
367 const auto &key = std::get<0>(item->entry);
368 const auto bucket = (std::hash<Key>{}(key)) % new_capacity;
369
370 const auto next = item->next;
371 item->next = new_table[bucket];
372 new_table[bucket] = item;
373 item = next;
374 }
375 }
376
377 TAllocator::free(_table, sizeof(chain *) * _capacity);
378 _table = new_table;
379 _capacity = new_capacity;
380 }
381} // namespace mos
382