#ifndef RUN_H
#define RUN_H
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fcntl.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#if defined _WIN32
#include "win.h"
#else
#include <sys/mman.h>
#include <unistd.h>
#endif
extern int GS;

typedef struct
{
    std::unique_ptr<int8_t[]> q; // quantized values
    std::unique_ptr<float[]> s;  // scaling factors
} QuantizedTensor;
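
// Layout note (informal sketch): a tensor of n values quantized in groups of GS
// stores n int8_t entries in q and n / GS per-group scales in s, so a value is
// reconstructed approximately as q[i] * s[i / GS] (see dequantize() below).
// For example, with GS = 2, q = {127, -64} and s = {0.0079f} decode to roughly
// {1.0f, -0.5f}.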

class Config
{
public:
    int dim;        // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers;   // number of layers
    int n_heads;    // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len;    // max sequence length
};

template<typename T>
class TransformerWeights
{
public:
    // token embedding table
    std::unique_ptr<float[]> token_embedding_table; // (vocab_size, dim)
    // final rmsnorm
    std::unique_ptr<float[]> rms_final_weight; // (dim,)
    // weights for rmsnorms
    std::unique_ptr<float[]> rms_att_weight; // (layer, dim) rmsnorm weights
    std::unique_ptr<float[]> rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    std::unique_ptr<T[]> wq; // (layer, dim, n_heads * head_size)
    std::unique_ptr<T[]> wk; // (layer, dim, n_kv_heads * head_size)
    std::unique_ptr<T[]> wv; // (layer, dim, n_kv_heads * head_size)
    std::unique_ptr<T[]> wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    std::unique_ptr<T[]> w1; // (layer, hidden_dim, dim)
    std::unique_ptr<T[]> w2; // (layer, dim, hidden_dim)
    std::unique_ptr<T[]> w3; // (layer, hidden_dim, dim)
    // (optional) classifier weights for the logits, on the last layer
    std::unique_ptr<T[]> wcls; // (vocab_size, dim)
    // tensor2d freq_cis_real; // [seq_len, (dim/n_heads)/2]
    // tensor2d freq_cis_imag; // [seq_len, (dim/n_heads)/2]
    std::unique_ptr<T[]> q_tokens; // (vocab_size, dim)
};

template<typename T>
class RunState
{
public:
    // current wave of activations
    std::unique_ptr<float[]> x;   // activation at current time stamp (dim,)
    std::unique_ptr<float[]> xb;  // same, but inside a residual branch (dim,)
    std::unique_ptr<float[]> xb2; // an additional buffer just for convenience (dim,)
    std::unique_ptr<float[]> hb;  // buffer for hidden dimension in the ffn (hidden_dim,)
    std::unique_ptr<float[]> hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    std::unique_ptr<float[]> q;   // query (dim,)
    std::unique_ptr<float[]> k;   // key (dim,)
    std::unique_ptr<float[]> v;   // value (dim,)
    std::unique_ptr<float[]> att; // buffer for scores/attention values (n_heads, seq_len)
    std::unique_ptr<float[]> logits; // output logits
    // kv cache
    std::unique_ptr<float[]> key_cache;   // (layer, seq_len, dim)
    std::unique_ptr<float[]> value_cache; // (layer, seq_len, dim)
};

template<>
class RunState<float>
{
public:
    // current wave of activations
    std::unique_ptr<float[]> x;   // activation at current time stamp (dim,)
    std::unique_ptr<float[]> xb;  // same, but inside a residual branch (dim,)
    std::unique_ptr<float[]> xb2; // an additional buffer just for convenience (dim,)
    std::unique_ptr<float[]> hb;  // buffer for hidden dimension in the ffn (hidden_dim,)
    std::unique_ptr<float[]> hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    std::unique_ptr<float[]> q;   // query (dim,)
    std::unique_ptr<float[]> k;   // key (dim,)
    std::unique_ptr<float[]> v;   // value (dim,)
    std::unique_ptr<float[]> att; // buffer for scores/attention values (n_heads, seq_len)
    std::unique_ptr<float[]> logits; // output logits
    // kv cache
    std::unique_ptr<float[]> key_cache;   // (layer, seq_len, dim)
    std::unique_ptr<float[]> value_cache; // (layer, seq_len, dim)
};

template<>
class RunState<QuantizedTensor>
{
public:
    // current wave of activations
    std::unique_ptr<float[]> x;   // activation at current time stamp (dim,)
    std::unique_ptr<float[]> xb;  // same, but inside a residual branch (dim,)
    std::unique_ptr<float[]> xb2; // an additional buffer just for convenience (dim,)
    std::unique_ptr<float[]> hb;  // buffer for hidden dimension in the ffn (hidden_dim,)
    std::unique_ptr<float[]> hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    std::unique_ptr<float[]> q;   // query (dim,)
    std::unique_ptr<float[]> k;   // key (dim,)
    std::unique_ptr<float[]> v;   // value (dim,)
    std::unique_ptr<float[]> att; // buffer for scores/attention values (n_heads, seq_len)
    std::unique_ptr<float[]> logits; // output logits
    // kv cache
    std::unique_ptr<float[]> key_cache;   // (layer, seq_len, dim)
    std::unique_ptr<float[]> value_cache; // (layer, seq_len, dim)

    std::unique_ptr<QuantizedTensor[]> xq; // quantized x (dim,)
    std::unique_ptr<QuantizedTensor[]> hq; // quantized hb (hidden_dim,)
};

typedef struct
{
    std::string str;
    int id;
} TokenIndex;

inline bool compare_tokens(const TokenIndex &a, const TokenIndex &b)
{
    return a.str < b.str;
}

inline int str_lookup(const std::string &str, const std::unique_ptr<TokenIndex[]> &sorted_vocab, int vocab_size)
{
    // efficiently find the perfect match for str in vocab, return its index or -1 if not found
    TokenIndex tok = { .str = str }; // acts as the key to search for

    auto it = std::lower_bound(sorted_vocab.get(), sorted_vocab.get() + vocab_size, tok, compare_tokens);

    // If we didn't reach the end and the string matches
    if (it != (sorted_vocab.get() + vocab_size) && it->str == str)
    {
        return it->id;
    }

    return -1; // Not found
}
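
// Note (informal): str_lookup assumes sorted_vocab is already sorted with
// compare_tokens (i.e. lexicographically by str); that ordering is what makes the
// std::lower_bound binary search valid. For example, with a sorted vocab of
// {"a", "ab", "b"}, str_lookup("ab", ...) returns the id stored for "ab", while
// str_lookup("abc", ...) returns -1.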

template<typename T>
class Transformer
{
private:
    void malloc_weights();
    void malloc_run_state();

public:
    Config config;
    TransformerWeights<T> w;
    RunState<T> s;
    int shared_weights = 1;
    void load_model(const std::string &checkpoint_path);
    float *forward(int token, int pos);
};

class Tokenizer
{
public:
    std::vector<std::unique_ptr<char[]>> vocab;
    std::vector<float> vocab_scores;
    std::unique_ptr<TokenIndex[]> sorted_vocab;
    int vocab_size;
    unsigned int max_token_length;
    unsigned char byte_pieces[512]; // stores all single-byte strings
    void build_tokenizer(const std::string &tokenizer_path, int size_for_vocab);
    void encode(const std::string &text, const int8_t &bos, const int8_t &eos, std::unique_ptr<int[]> &tokens, int &n_tokens);
    std::string decode(int prev_token, int token);
};

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
typedef struct
{
    float prob;
    int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

class Sampler
{
private:
    int sample_argmax(float *probabilities, int n);
    int sample_mult(float *probabilities, int n, float coin);
    int sample_topp(float *probabilities, int n, float topp, std::unique_ptr<ProbIndex[]> &probindex, float coin);
    unsigned int random_u32(unsigned long long *state)
    {
        // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
        *state ^= *state >> 12;
        *state ^= *state << 25;
        *state ^= *state >> 27;
        return (*state * 0x2545F4914F6CDD1Dull) >> 32;
    }
    float random_f32(unsigned long long *state)
    { // random float32 in [0,1)
        return (random_u32(state) >> 8) / 16777216.0f;
    }
    static bool compare_probindex(const ProbIndex &a, const ProbIndex &b)
    {
        return a.prob > b.prob;
    }

public:
    int vocab_size;
    std::unique_ptr<ProbIndex[]> probindex; // buffer used in top-p sampling
    float temperature;
    float topp;
    unsigned long long rng_state;
    void build_sampler(int vocab_size, float temperature, float topp, unsigned long long rng_seed);
    int sample(float *logits);
};
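
// Rough usage sketch (informal): the real generation loop lives in the .cpp, and
// the file names, prompt, and sampling parameters below are only illustrative.
// It shows how the pieces declared above are intended to fit together for a
// float (non-quantized) model:
//
//   Transformer<float> t;
//   t.load_model("model.bin");
//   Tokenizer tok;
//   tok.build_tokenizer("tokenizer.bin", t.config.vocab_size);
//   Sampler sampler;
//   sampler.build_sampler(t.config.vocab_size, /*temperature=*/1.0f, /*topp=*/0.9f, /*rng_seed=*/1234ULL);
//
//   std::unique_ptr<int[]> tokens = std::make_unique<int[]>(t.config.seq_len);
//   int n_tokens = 0;
//   tok.encode("Once upon a time", /*bos=*/1, /*eos=*/0, tokens, n_tokens);
//   int token = tokens[0];
//   for (int pos = 0; pos < t.config.seq_len; pos++)
//   {
//       float *logits = t.forward(token, pos);
//       int next = (pos < n_tokens - 1) ? tokens[pos + 1] : sampler.sample(logits);
//       safe_print(tok.decode(token, next));
//       token = next;
//   }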

inline bool is_quantized_model(const std::string &checkpoint_path)
{
    std::ifstream file(checkpoint_path, std::ios::binary);
    if (!file)
    {
        std::cerr << "Couldn't open file " << checkpoint_path << '\n';
        std::exit(EXIT_FAILURE);
    }
    uint32_t magic_number;
    int version;

    file.read(reinterpret_cast<char *>(&magic_number), sizeof(uint32_t));
    file.read(reinterpret_cast<char *>(&version), sizeof(int));

    file.close();
    // magic number 0x616b3432 spells "ak42" in ASCII; together with version 2 it
    // marks the int8-quantized checkpoint format
    if (magic_number != 0x616b3432 || version != 2)
    {
        return false;
    }
    return true;
}

inline void safe_print(const std::string &piece)
{
    if (piece.empty())
    {
        return;
    }
    if (piece.size() == 1)
    {
        unsigned char byte_val = piece[0];
        if (!(isprint(byte_val) || isspace(byte_val)))
        {
            return; // bad byte, don't print it
        }
    }
    std::cout << piece;
}

inline long time_in_ms()
{
    // return time in milliseconds, for benchmarking the model speed
    auto now = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<std::chrono::milliseconds>(now).count();
}

inline void softmax(float *x, int size)
{
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++)
    {
        if (x[i] > max_val)
        {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++)
    {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++)
    {
        x[i] /= sum;
    }
}
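
// Worked example (informal): softmax({1.0f, 2.0f, 3.0f}) first subtracts the max,
// giving exp(-2) ~ 0.135, exp(-1) ~ 0.368, exp(0) = 1.0 (sum ~ 1.503), and then
// normalizes to roughly {0.090, 0.245, 0.665}.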

inline void rmsnorm(float *o, float *x, float *weight, int size)
{
    // calculate sum of squares
    float ss = 0.0f;
    for (int j = 0; j < size; j++)
    {
        ss += x[j] * x[j];
    }
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    // normalize and scale
    for (int j = 0; j < size; j++)
    {
        o[j] = weight[j] * (ss * x[j]);
    }
}
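
// In formula form, rmsnorm computes o[j] = weight[j] * x[j] / sqrt(mean(x^2) + 1e-5),
// i.e. it rescales x by the reciprocal of its root mean square (plus a small epsilon
// for numerical stability) and applies a learned per-dimension gain.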

inline void matmul(float *xout, float *x, float *w, int n, int d)
{
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    int i;
    for (i = 0; i < d; i++)
    {
        float val = 0.0f;
        for (int j = 0; j < n; j++)
        {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}
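
// Indexing note (informal): w is stored row-major, so w[i * n + j] is W[i][j] and
// each xout[i] is the dot product of row i of W with x. For example, with d = 2,
// n = 2, w = {1, 2, 3, 4} and x = {1, 1}, xout becomes {3, 7}.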

inline void q_matmul(float *xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d)
{
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    // inputs to this function are both quantized

    int i;
    for (i = 0; i < d; i++)
    {

        float val = 0.0f;
        int32_t ival = 0;
        int in = i * n;

        // do the matmul in groups of GS
        int j;
        for (j = 0; j <= n - GS; j += GS)
        {
            for (int k = 0; k < GS; k++)
            {
                ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
            }
            val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
            ival = 0;
        }

        xout[i] = val;
    }
}
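
// Worked example (informal), with GS = 2 and a single group: if a row of w has
// q = {4, 5} with scale 0.25f and x has q = {2, 3} with scale 0.5f, the integer
// dot product is 2*4 + 3*5 = 23 and the contribution to xout[i] is
// 23 * 0.25 * 0.5 = 2.875; the int8 accumulation is rescaled by both per-group scales.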

inline void read_stdin(const std::string &guide, std::string &buffer, size_t max_len)
{
    std::cout << guide;
    std::getline(std::cin, buffer);
    if (buffer.length() > max_len)
    {
        buffer.resize(max_len);
    }
}

inline void dequantize(QuantizedTensor *qx, float *x, int n)
{
    for (int i = 0; i < n; i++)
    {
        x[i] = qx->q[i] * qx->s[i / GS];
    }
}

inline void quantize(QuantizedTensor *qx, float *x, int n)
{
    int num_groups = n / GS;
    float Q_MAX = 127.0f;

    for (int group = 0; group < num_groups; group++)
    {
        // find the max absolute value in the current group
        float wmax = 0.0;
        for (int i = 0; i < GS; i++)
        {
            float val = fabs(x[group * GS + i]);
            if (val > wmax)
            {
                wmax = val;
            }
        }

        // calculate and write the scaling factor
        float scale = wmax / Q_MAX;
        qx->s[group] = scale;

        // calculate and write the quantized values
        for (int i = 0; i < GS; i++)
        {
            float quant_value = x[group * GS + i] / scale; // scale
            int8_t quantized = (int8_t) round(quant_value); // round and clamp
            qx->q[group * GS + i] = quantized;
        }
    }
}
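
// Worked example (informal), with GS = 4: for the group {0.5f, -1.0f, 0.25f, 0.75f}
// the max absolute value is 1.0, so scale = 1.0 / 127 ~ 0.00787 and the quantized
// int8 values come out to roughly {64, -127, 32, 95}; dequantize() multiplies them
// back by the stored scale.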

inline void init_quantized_tensors(std::ifstream &file, QuantizedTensor *w, int n_layers, int each_layer)
{

    for (int i = 0; i < n_layers; i++)
    {
        w[i].q = std::make_unique<int8_t[]>(each_layer);
        w[i].s = std::make_unique<float[]>(each_layer / GS);
        file.read(reinterpret_cast<char *>(w[i].q.get()), each_layer * sizeof(int8_t));
        file.read(reinterpret_cast<char *>(w[i].s.get()), each_layer / GS * sizeof(float));
    }
}

#endif