1// SPDX-License-Identifier: GPL-3.0-or-later
2
3#pragma once
4
5#include <atomic>
6#include <cstddef>
7#include <mos/type_utils.hpp>
8#include <mos/types.h>
9#include <stdnoreturn.h>
10#include <type_traits>
11
// In C++ translation units atomic_t is simply the standard library's size_t-wide atomic.
using atomic_t = std::atomic_size_t;
14
// Byte offset of the data member designated by `member` within an object of type P.
//
// NOTE: this is the classic "address of the member of a null object" trick. Because it needs
// reinterpret_cast it can never be evaluated in a constant expression, and applying ->* to a
// null pointer is formally undefined behaviour; it relies on the compiler treating the member
// access as plain address arithmetic. Where the member is known by name and P is
// standard-layout, offsetof() is the well-defined alternative.
template<class P, class M>
[[gnu::always_inline]] constexpr size_t __offsetof(const M P::*member)
{
    return reinterpret_cast<size_t>(&(static_cast<P *>(nullptr)->*member));
}

// Recovers a pointer to the enclosing P from a pointer to its `member` subobject.
// `ptr` must point at the `member` field of a P object; nothing is checked.
template<class P, class M>
[[gnu::always_inline]] constexpr inline P *__container_of(M *ptr, const M P::*member)
{
    return (P *) ((char *) ptr - __offsetof(member));
}

// Overload for pointers-to-const: the recovered enclosing object stays const.
template<class P, class M>
[[gnu::always_inline]] constexpr inline const P *__container_of(const M *ptr, const M P::*member)
{
    // Do the byte arithmetic on `const char *` so constness is never cast away, not even
    // transiently. C-style casts are kept (as in the non-const overload) because, unlike
    // reinterpret_cast, they still accept volatile-qualified member types.
    return (const P *) ((const char *) ptr - __offsetof(member));
}

// container_of(ptr, type, member): pointer to the `type` object whose `member` field `ptr` points to.
#define container_of(ptr, type, member) __container_of(ptr, &type::member)
34
// Reinterprets a mutable TIn pointer as a mutable TOut pointer (no adjustment, no checks).
template<typename TOut, typename TIn>
[[gnu::always_inline]] inline auto cast(TIn *value) -> TOut *
{
    auto *const reinterpreted = reinterpret_cast<TOut *>(value);
    return reinterpreted;
}
40
// Reinterprets a pointer-to-const TIn as a pointer-to-const TOut; constness is preserved.
template<typename TOut, typename TIn>
[[gnu::always_inline]] inline auto cast(const TIn *value) -> const TOut *
{
    const auto *const reinterpreted = reinterpret_cast<const TOut *>(value);
    return reinterpreted;
}
46
// State shared by every PtrResult: an error code, where 0 means success.
//
// The accessors are intentionally non-virtual: derived results only reuse them. Virtual
// functions would add a vtable pointer to every result object and make it non-trivially
// copyable, with no overriding possible to justify the cost.
struct PtrResultBase
{
  protected:
    const int errorCode; // 0 on success, otherwise the error code given at construction

    PtrResultBase() : errorCode(0) {}
    PtrResultBase(int errorCode) : errorCode(errorCode) {}

  public:
    // True when a non-zero error code is stored.
    bool isErr() const
    {
        return errorCode != 0;
    }

    // The stored error code (0 on success), widened to long.
    long getErr() const
    {
        return errorCode;
    }

    // True on success, i.e. when no error code is stored.
    explicit operator bool() const
    {
        return !isErr();
    }
};
71
72template<typename T>
73struct PtrResult : public PtrResultBase
74{
75 private:
76 T *const value;
77
78 public:
79 PtrResult(T *value) : PtrResultBase(0), value(value) {};
80 PtrResult(int errorCode) : PtrResultBase(errorCode), value(nullptr) {};
81
82 template<typename U>
83 requires std::is_base_of_v<T, U> PtrResult(PtrResult<U> other) : PtrResultBase(other.getErr()), value(other.get()) {};
84
85 public:
86 std::add_lvalue_reference<T>::type operator*()
87 {
88 return *value;
89 }
90
91 const std::add_lvalue_reference<T>::type &operator*() const
92 {
93 return *value;
94 }
95
96 T *operator->()
97 {
98 return value;
99 }
100
101 const T *operator->() const
102 {
103 return value;
104 }
105
106 T *get() const
107 {
108 if (isErr())
109 mos::__raise_bad_ptrresult_value(errorCode);
110 return value;
111 }
112
113 bool operator==(const std::nullptr_t) const
114 {
115 return value == nullptr;
116 }
117
118 bool operator==(const PtrResult<T> &other) const
119 {
120 return value == other.value && errorCode == other.errorCode;
121 }
122
123 bool operator==(const T *other) const
124 {
125 return value == other;
126 }
127
128 explicit operator bool() const
129 {
130 return value != nullptr && !isErr();
131 }
132};
133
134template<>
135struct PtrResult<void> : public PtrResultBase
136{
137 public:
138 PtrResult() : PtrResultBase(0) {};
139 PtrResult(int errorCode) : PtrResultBase(errorCode) {};
140};
141
142template<typename E>
143requires std::is_enum_v<E> struct Flags
144{
145 private:
146 E value_;
147 static_assert(sizeof(E) <= sizeof(u32), "Flags only supports enums that fit into a u32");
148
149 using EnumType = E;
150 using U = std::underlying_type_t<E>;
151
152 public:
153 constexpr Flags(E value = static_cast<E>(0)) : value_(value)
154 {
155 static_assert(sizeof(Flags) == sizeof(E), "Flags must have the same size as the enum");
156 }
157
158 ~Flags() = default;
159
160 static Flags all()
161 {
162 return static_cast<E>(~static_cast<U>(0));
163 }
164
165 inline operator U() const
166 {
167 return static_cast<U>(value_);
168 }
169
170 inline Flags operator|(E e) const
171 {
172 return static_cast<E>(static_cast<U>(value_) | static_cast<U>(e));
173 }
174
175 inline Flags operator&(E b) const
176 {
177 return static_cast<E>(static_cast<U>(value_) & static_cast<U>(b));
178 }
179
180 inline Flags operator&(Flags b) const
181 {
182 return static_cast<E>(static_cast<U>(value_) & static_cast<U>(b.value_));
183 }
184
185 inline bool test(E b) const
186 {
187 return static_cast<U>(value_) & static_cast<U>(b);
188 }
189
190 inline bool test_inverse(E b) const
191 {
192 return static_cast<U>(value_) & ~static_cast<U>(b);
193 }
194
195 inline Flags erased(E b) const
196 {
197 return static_cast<E>(static_cast<U>(value_) & ~static_cast<U>(b));
198 }
199
200 inline Flags erased(Flags b) const
201 {
202 return static_cast<E>(static_cast<U>(value_) & ~static_cast<U>(b.value_));
203 }
204
205 inline Flags &operator|=(E b)
206 {
207 value_ = static_cast<E>(static_cast<U>(value_) | static_cast<U>(b));
208 return *this;
209 }
210
211 inline Flags &operator&=(E b)
212 {
213 value_ = static_cast<E>(static_cast<U>(value_) & static_cast<U>(b));
214 return *this;
215 }
216
217 inline Flags &operator&=(Flags b)
218 {
219 value_ = static_cast<E>(static_cast<U>(value_) & static_cast<U>(b.value_));
220 return *this;
221 }
222
223 inline Flags erase(E b)
224 {
225 value_ = static_cast<E>(static_cast<U>(value_) & ~static_cast<U>(b));
226 return *this;
227 }
228
229 inline Flags erase(Flags b)
230 {
231 value_ = static_cast<E>(static_cast<U>(value_) & ~static_cast<U>(b.value_));
232 return *this;
233 }
234};
235
// MOS_ENUM_FLAGS(EnumT, FlagsName) declares FlagsName as an alias of Flags<EnumT>.
#define MOS_ENUM_FLAGS(EnumT, FlagsName) using FlagsName = Flags<EnumT>
237
238template<typename E>
239requires std::is_enum_v<E> constexpr Flags<E> operator|(E a, E b)
240{
241 using U = std::underlying_type_t<E>;
242 return static_cast<E>(static_cast<U>(a) | static_cast<U>(b));
243}
244
245template<typename E>
246requires std::is_enum_v<E> constexpr void operator~(E a) = delete;
247
248template<typename E>
249requires std::is_enum_v<E> constexpr void operator&(E a, E b) = delete;
250