1#ifndef RUN_H
2#define RUN_H
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fcntl.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
16
17#if defined _WIN32
18#include "win.h"
19#else
20#include <sys/mman.h>
21#include <unistd.h>
22#endif
23
24extern int GS;
25
// Groupwise symmetric int8 quantization: `q` holds the quantized values and
// `s` holds one float scaling factor per group of GS consecutive values.
// (Was a C-style `typedef struct`; plain `struct` is the C++ idiom.)
struct QuantizedTensor
{
    std::unique_ptr<int8_t[]> q; // quantized values
    std::unique_ptr<float[]> s;  // scaling factors, one per GS-sized group
};
31
// Model hyper-parameters, filled in from the checkpoint file header.
struct Config
{
    int dim;        // transformer (embedding) dimension
    int hidden_dim; // hidden dimension of the ffn layers
    int n_layers;   // number of transformer layers
    int n_heads;    // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len;    // max sequence length
};
43
// All weight tensors of one checkpoint. Norm weights are always float;
// the matmul weights use T (float for fp32 checkpoints, QuantizedTensor
// for int8-quantized ones).
template<typename T>
struct TransformerWeights
{
    // token embedding table: (vocab_size, dim)
    std::unique_ptr<float[]> token_embedding_table;
    // final rmsnorm weight: (dim,)
    std::unique_ptr<float[]> rms_final_weight;
    // per-layer rmsnorm weights
    std::unique_ptr<float[]> rms_att_weight; // (layer, dim) pre-attention
    std::unique_ptr<float[]> rms_ffn_weight; // (layer, dim) pre-ffn
    // attention matmul weights; note dim == n_heads * head_size
    std::unique_ptr<T[]> wq; // (layer, dim, n_heads * head_size)
    std::unique_ptr<T[]> wk; // (layer, dim, n_kv_heads * head_size)
    std::unique_ptr<T[]> wv; // (layer, dim, n_kv_heads * head_size)
    std::unique_ptr<T[]> wo; // (layer, n_heads * head_size, dim)
    // feed-forward matmul weights
    std::unique_ptr<T[]> w1; // (layer, hidden_dim, dim)
    std::unique_ptr<T[]> w2; // (layer, dim, hidden_dim)
    std::unique_ptr<T[]> w3; // (layer, hidden_dim, dim)
    // (optional) classifier weights for the logits, on the last layer
    std::unique_ptr<T[]> wcls;
    // tensor2d freq_cis_real; // [seq_len, (dim/n_heads)/2]
    // tensor2d freq_cis_imag; // [seq_len, (dim/n_heads)/2]
    // (vocab_size, dim) -- presumably the quantized form of the token
    // embedding table for int8 checkpoints; TODO confirm in load_model
    std::unique_ptr<T[]> q_tokens;
};
70
// Scratch activation buffers for one forward pass. This primary template
// covers the float path; the QuantizedTensor specialization below adds
// quantized staging buffers.
template<typename T>
struct RunState
{
    std::unique_ptr<float[]> x;      // activation at the current time step (dim,)
    std::unique_ptr<float[]> xb;     // activation inside a residual branch (dim,)
    std::unique_ptr<float[]> xb2;    // extra convenience buffer (dim,)
    std::unique_ptr<float[]> hb;     // ffn hidden buffer (hidden_dim,)
    std::unique_ptr<float[]> hb2;    // second ffn hidden buffer (hidden_dim,)
    std::unique_ptr<float[]> q;      // query (dim,)
    std::unique_ptr<float[]> k;      // key (dim,)
    std::unique_ptr<float[]> v;      // value (dim,)
    std::unique_ptr<float[]> att;    // attention scores (n_heads, seq_len)
    std::unique_ptr<float[]> logits; // output logits
    // kv cache
    std::unique_ptr<float[]> key_cache;   // (layer, seq_len, dim)
    std::unique_ptr<float[]> value_cache; // (layer, seq_len, dim)
};
90
// Forward-pass scratch buffers for the fp32 path.
// NOTE(review): this explicit specialization is member-for-member identical
// to the primary RunState<T> template instantiated with T = float; it adds
// nothing and could likely be deleted without changing behavior.
template<>
class RunState<float>
{
    public:
        // current wave of activations
        std::unique_ptr<float[]> x; // activation at current time stamp (dim,)
        std::unique_ptr<float[]> xb; // same, but inside a residual branch (dim,)
        std::unique_ptr<float[]> xb2; // an additional buffer just for convenience (dim,)
        std::unique_ptr<float[]> hb; // buffer for hidden dimension in the ffn (hidden_dim,)
        std::unique_ptr<float[]> hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
        std::unique_ptr<float[]> q; // query (dim,)
        std::unique_ptr<float[]> k; // key (dim,)
        std::unique_ptr<float[]> v; // value (dim,)
        std::unique_ptr<float[]> att; // buffer for scores/attention values (n_heads, seq_len)
        std::unique_ptr<float[]> logits; // output logits
        // kv cache
        std::unique_ptr<float[]> key_cache; // (layer, seq_len, dim)
        std::unique_ptr<float[]> value_cache; // (layer, seq_len, dim)
};
110
111template<>
112class RunState<QuantizedTensor>
113{
114 public:
115 // current wave of activations
116 std::unique_ptr<float[]> x; // activation at current time stamp (dim,)
117 std::unique_ptr<float[]> xb; // same, but inside a residual branch (dim,)
118 std::unique_ptr<float[]> xb2; // an additional buffer just for convenience (dim,)
119 std::unique_ptr<float[]> hb; // buffer for hidden dimension in the ffn (hidden_dim,)
120 std::unique_ptr<float[]> hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
121 std::unique_ptr<float[]> q; // query (dim,)
122 std::unique_ptr<float[]> k; // key (dim,)
123 std::unique_ptr<float[]> v; // value (dim,)
124 std::unique_ptr<float[]> att; // buffer for scores/attention values (n_heads, seq_len)
125 std::unique_ptr<float[]> logits; // output logits
126 // kv cache
127 std::unique_ptr<float[]> key_cache; // (layer, seq_len, dim)
128 std::unique_ptr<float[]> value_cache; // (layer, seq_len, dim)
129
130 std::unique_ptr<QuantizedTensor[]> xq; // quantized x (dim,)
131 std::unique_ptr<QuantizedTensor[]> hq; // quantized hb (hidden_dim,)
132};
133
// A vocabulary entry paired with its token id; an array of these, sorted by
// `str`, supports binary-search lookup (see str_lookup / compare_tokens).
// (Was a C-style `typedef struct`; plain `struct` is the C++ idiom.)
struct TokenIndex
{
    std::string str; // token text
    int id;          // token id in the vocabulary
};
139
140inline bool compare_tokens(const TokenIndex &a, const TokenIndex &b)
141{
142 return a.str < b.str;
143}
144
145inline int str_lookup(const std::string &str, const std::unique_ptr<TokenIndex[]> &sorted_vocab, int vocab_size)
146{
147 // efficiently find the perfect match for str in vocab, return its index or -1 if not found
148 TokenIndex tok = { .str = str }; // acts as the key to search for
149
150 auto it = std::lower_bound(first: sorted_vocab.get(), last: sorted_vocab.get() + vocab_size, val: tok, comp: compare_tokens);
151
152 // If we didn't reach the end and the string matches
153 if (it != (sorted_vocab.get() + vocab_size) && it->str == str)
154 {
155 return it->id;
156 }
157
158 return -1; // Not found
159}
160
// A complete model: hyper-parameters, weights, and forward-pass scratch
// state. T selects weight storage (float or QuantizedTensor). Method
// bodies are defined elsewhere.
template<typename T>
class Transformer
{
    private:
        void malloc_weights();   // presumably allocates weight buffers per config -- see definition
        void malloc_run_state(); // presumably allocates the RunState buffers -- see definition

    public:
        Config config;           // model hyper-parameters (from the checkpoint header)
        TransformerWeights<T> w; // model weights
        RunState<T> s;           // per-step activation buffers
        // nonzero when the output classifier shares the token embedding table;
        // NOTE(review): presumably set while loading -- confirm in load_model
        int shared_weights = 1;
        // read config and weights from the checkpoint file at checkpoint_path
        void load_model(const std::string &checkpoint_path);
        // run one forward step for `token` at position `pos`; returns a pointer
        // to the logits -- presumably s.logits; confirm in the definition
        float *forward(int token, int pos);
};
176
177class Tokenizer
178{
179 public:
180 std::vector<std::unique_ptr<char[]>> vocab;
181 std::vector<float> vocab_scores;
182 std::unique_ptr<TokenIndex[]> sorted_vocab;
183 int vocab_size;
184 unsigned int max_token_length;
185 unsigned char byte_pieces[512]; // stores all single-byte strings
186 void build_tokenizer(const std::string &tokenizer_path, int size_for_vacab);
187 void encode(const std::string &text, const int8_t &bos, const int8_t &eos, std::unique_ptr<int[]> &tokens, int &n_tokens);
188 std::string decode(int prev_token, int token);
189};
190
191// ----------------------------------------------------------------------------
192// The Sampler, which takes logits and returns a sampled token
193// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
// A probability paired with its token index, used when sorting
// probabilities during top-p sampling.
// (Was a C-style `typedef struct`; plain `struct` is the C++ idiom.)
struct ProbIndex
{
    float prob; // token probability
    int index;  // token id
};
199
200class Sampler
201{
202 private:
203 int sample_argmax(float *probabilities, int n);
204 int sample_mult(float *probabilities, int n, float coin);
205 int sample_topp(float *probabilities, int n, float topp, std::unique_ptr<ProbIndex[]> &probindex, float coin);
206 unsigned int random_u32(unsigned long long *state)
207 {
208 // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
209 *state ^= *state >> 12;
210 *state ^= *state << 25;
211 *state ^= *state >> 27;
212 return (*state * 0x2545F4914F6CDD1Dull) >> 32;
213 }
214 float random_f32(unsigned long long *state)
215 { // random float32 in [0,1)
216 return (random_u32(state) >> 8) / 16777216.0f;
217 }
218 static bool compare_probindex(const ProbIndex &a, const ProbIndex &b)
219 {
220 return a.prob > b.prob;
221 }
222
223 public:
224 int vocab_size;
225 std::unique_ptr<ProbIndex[]> probindex; // buffer used in top-p sampling
226 float temperature;
227 float topp;
228 unsigned long long rng_state;
229 void build_sampler(int vocab_size, float temperature, float topp, unsigned long long rng_seed);
230 int sample(float *logits);
231};
232
// Returns true when the checkpoint at checkpoint_path is an int8-quantized
// (version-2) model: its header starts with the magic 0x616b3432 ("ak42")
// followed by version number 2. Exits the process if the file cannot be
// opened. Fixes: the read calls carried IDE inlay-hint labels ("s:", "n:")
// pasted into the source (invalid C++), and a truncated file left
// magic_number/version uninitialized before the comparison.
inline bool is_quantized_model(const std::string &checkpoint_path)
{
    std::ifstream file(checkpoint_path, std::ios::binary);
    if (!file)
    {
        std::cerr << "Couldn't open file " << checkpoint_path << '\n';
        std::exit(EXIT_FAILURE);
    }
    uint32_t magic_number = 0;
    int version = 0;

    file.read(reinterpret_cast<char *>(&magic_number), sizeof(uint32_t));
    file.read(reinterpret_cast<char *>(&version), sizeof(int));

    // a short read (truncated file) means it cannot be a valid v2 header
    return file && magic_number == 0x616b3432 && version == 2;
}
254
// Print a decoded token piece to stdout, but silently drop single bytes that
// are neither printable nor whitespace (raw-byte tokens can decode to control
// characters or partial UTF-8). Fixes: isprint/isspace carried IDE inlay-hint
// labels ("c:") pasted into the source (invalid C++).
inline void safe_print(const std::string &piece)
{
    if (piece.empty())
    {
        return;
    }
    if (piece.size() == 1)
    {
        unsigned char byte_val = piece[0];
        if (!(isprint(byte_val) || isspace(byte_val)))
        {
            return; // bad byte, don't print it
        }
    }
    std::cout << piece;
}
271
// Return wall-clock time in milliseconds since the epoch, for benchmarking
// the model speed. Fixes: duration_cast carried an IDE inlay-hint label
// ("d:") pasted into the source (invalid C++).
inline long time_in_ms()
{
    auto now = std::chrono::system_clock::now().time_since_epoch();
    // NOTE: count() is long long; the long return narrows on 32-bit platforms
    return std::chrono::duration_cast<std::chrono::milliseconds>(now).count();
}
278
// In-place softmax over x[0..size-1], turning scores into a probability
// distribution. Fixes: expf carried an IDE inlay-hint label ("x:") pasted
// into the source (invalid C++).
inline void softmax(float *x, int size)
{
    // find max value (subtracted before exp, for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++)
    {
        if (x[i] > max_val)
        {
            max_val = x[i];
        }
    }
    // exponentiate and accumulate the normalizer
    float sum = 0.0f;
    for (int i = 0; i < size; i++)
    {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize so the entries sum to 1
    for (int i = 0; i < size; i++)
    {
        x[i] /= sum;
    }
}
303
// RMS normalization: o[j] = weight[j] * x[j] / rms(x), where
// rms(x) = sqrt(mean(x^2) + 1e-5). o and x may alias. Fixes: sqrtf carried
// an IDE inlay-hint label ("x:") pasted into the source (invalid C++).
inline void rmsnorm(float *o, float *x, float *weight, int size)
{
    // calculate mean of squares (plus epsilon to avoid division by zero)
    float ss = 0.0f;
    for (int j = 0; j < size; j++)
    {
        ss += x[j] * x[j];
    }
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    // normalize and scale
    for (int j = 0; j < size; j++)
    {
        o[j] = weight[j] * (ss * x[j]);
    }
}
321
// xout = W @ x, where W is (d,n) stored row-major and x is (n,).
// By far the most time of a forward pass is spent in this function.
inline void matmul(float *xout, float *x, float *w, int n, int d)
{
    for (int row = 0; row < d; row++)
    {
        const float *wrow = w + (size_t) row * n; // start of row `row` of W
        float acc = 0.0f;
        for (int col = 0; col < n; col++)
        {
            acc += wrow[col] * x[col];
        }
        xout[row] = acc;
    }
}
337
338inline void q_matmul(float *xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d)
339{
340 // W (d,n) @ x (n,) -> xout (d,)
341 // by far the most amount of time is spent inside this little function
342 // inputs to this function are both quantized
343
344 int i;
345 for (i = 0; i < d; i++)
346 {
347
348 float val = 0.0f;
349 int32_t ival = 0;
350 int in = i * n;
351
352 // do the matmul in groups of GS
353 int j;
354 for (j = 0; j <= n - GS; j += GS)
355 {
356 for (int k = 0; k < GS; k++)
357 {
358 ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
359 }
360 val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
361 ival = 0;
362 }
363
364 xout[i] = val;
365 }
366}
367
// Print the `guide` prompt, read one line from stdin into `buffer`, and
// truncate it to at most max_len characters. Fixes: getline/resize carried
// IDE inlay-hint labels ("is&:", "str&:", "n:") pasted into the source
// (invalid C++).
inline void read_stdin(const std::string &guide, std::string &buffer, size_t max_len)
{
    std::cout << guide;
    std::getline(std::cin, buffer);
    if (buffer.length() > max_len)
    {
        buffer.resize(max_len);
    }
}
377
378inline void dequantize(QuantizedTensor *qx, float *x, int n)
379{
380 for (int i = 0; i < n; i++)
381 {
382 x[i] = qx->q[i] * qx->s[i / GS];
383 }
384}
385
386inline void quantize(QuantizedTensor *qx, float *x, int n)
387{
388 int num_groups = n / GS;
389 float Q_MAX = 127.0f;
390
391 for (int group = 0; group < num_groups; group++)
392 {
393 // find the max absolute value in the current group
394 float wmax = 0.0;
395 for (int i = 0; i < GS; i++)
396 {
397 float val = fabs(x: x[group * GS + i]);
398 if (val > wmax)
399 {
400 wmax = val;
401 }
402 }
403
404 // calculate and write the scaling factor
405 float scale = wmax / Q_MAX;
406 qx->s[group] = scale;
407
408 // calculate and write the quantized values
409 for (int i = 0; i < GS; i++)
410 {
411 float quant_value = x[group * GS + i] / scale; // scale
412 int8_t quantized = (int8_t) round(x: quant_value); // round and clamp
413 qx->q[group * GS + i] = quantized;
414 }
415 }
416}
417
418inline void init_quantized_tensors(std::ifstream &file, QuantizedTensor *w, int n_layers, int each_layer)
419{
420
421 for (int i = 0; i < n_layers; i++)
422 {
423 w[i].q = std::make_unique<int8_t[]>(num: each_layer);
424 w[i].s = std::make_unique<float[]>(num: each_layer / GS);
425 file.read(s: reinterpret_cast<char *>(w[i].q.get()), n: each_layer * sizeof(int8_t));
426 file.read(s: reinterpret_cast<char *>(w[i].s.get()), n: each_layer / GS * sizeof(float));
427 }
428}
429
430#endif
431