#include "run.h"
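// global quantization group size: every GS consecutive int8 weights share one
// fp32 scale. It is read from the header of version-2 (quantized) checkpoints
// in Transformer<QuantizedTensor>::load_model() below.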
int GS = 0;

void error_usage()
{
    std::cerr << R"(Usage: run <checkpoint> [options]
Example: run model.bin -n 256 -i "Once upon a time"
Options:
  -t <float> temperature in [0,inf], default 1.0
  -p <float> p value in top-p (nucleus) sampling in [0,1], default 0.9
  -s <int> random seed, default time(NULL)
  -n <int> number of steps to run for, default 256. 0 = max_seq_len
  -i <string> input prompt
  -z <string> optional path to custom tokenizer
  -m <string> mode: generate|chat, default: generate
  -y <string> (optional) system prompt in chat mode
)";
    std::exit(EXIT_FAILURE);
}

template<>
void Transformer<float>::malloc_weights()
{
    int head_size = config.dim / config.n_heads;
    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
    unsigned long long n_layers = config.n_layers;
    w.token_embedding_table = std::make_unique<float[]>(config.vocab_size * config.dim);
    w.rms_att_weight = std::make_unique<float[]>(n_layers * config.dim);
    w.wq = std::make_unique<float[]>(n_layers * config.dim * config.n_heads * head_size);
    w.wk = std::make_unique<float[]>(n_layers * config.dim * config.n_kv_heads * head_size);
    w.wv = std::make_unique<float[]>(n_layers * config.dim * config.n_kv_heads * head_size);
    w.wo = std::make_unique<float[]>(n_layers * config.dim * config.n_heads * head_size);
    w.rms_ffn_weight = std::make_unique<float[]>(n_layers * config.dim);
    w.w1 = std::make_unique<float[]>(n_layers * config.dim * config.hidden_dim);
    w.w2 = std::make_unique<float[]>(n_layers * config.dim * config.hidden_dim);
    w.w3 = std::make_unique<float[]>(n_layers * config.dim * config.hidden_dim);
    w.rms_final_weight = std::make_unique<float[]>(config.dim);
    if (!shared_weights)
    {
        w.wcls = std::make_unique<float[]>(config.vocab_size * config.dim);
        if (!w.wcls.get())
        {
            std::cerr << "Malloc for wcls weights failed.\n";
            std::exit(EXIT_FAILURE);
        }
    }
    if (!w.token_embedding_table.get() || !w.rms_att_weight.get() || !w.wq.get() || !w.wk.get() || !w.wv.get() || !w.wo.get() || !w.rms_ffn_weight.get() || !w.w1.get() ||
        !w.w2.get() || !w.w3.get() || !w.rms_final_weight.get())
    {
        std::cerr << "Malloc for weights failed.\n";
        std::exit(EXIT_FAILURE);
    }
}

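// Quantized variant. QuantizedTensor is defined in run.h; judging from its use
// in malloc_run_state() below, it holds a block of int8 values `q` plus fp32
// scale factors `s`, one scale per group of GS values. Each layer's weight
// matrix is a single such tensor, so only n_layers handles are allocated here.
// Note that std::make_unique throws std::bad_alloc on failure, so the null
// checks in these allocators are defensive rather than the primary error path.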
template<>
void Transformer<QuantizedTensor>::malloc_weights()
{
    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
    unsigned long long n_layers = config.n_layers;
    w.token_embedding_table = std::make_unique<float[]>(config.vocab_size * config.dim);
    w.rms_att_weight = std::make_unique<float[]>(n_layers * config.dim);
    w.wq = std::make_unique<QuantizedTensor[]>(n_layers);
    w.wk = std::make_unique<QuantizedTensor[]>(n_layers);
    w.wv = std::make_unique<QuantizedTensor[]>(n_layers);
    w.wo = std::make_unique<QuantizedTensor[]>(n_layers);
    w.rms_ffn_weight = std::make_unique<float[]>(n_layers * config.dim);
    w.w1 = std::make_unique<QuantizedTensor[]>(n_layers);
    w.w2 = std::make_unique<QuantizedTensor[]>(n_layers);
    w.w3 = std::make_unique<QuantizedTensor[]>(n_layers);
    w.rms_final_weight = std::make_unique<float[]>(config.dim);

    w.q_tokens = std::make_unique<QuantizedTensor[]>(1);

    if (!shared_weights)
    {
        w.wcls = std::make_unique<QuantizedTensor[]>(1);
        if (!w.wcls.get())
        {
            std::cerr << "Malloc for wcls weights failed.\n";
            std::exit(EXIT_FAILURE);
        }
    }
    if (!w.token_embedding_table.get() || !w.rms_att_weight.get() || !w.wq.get() || !w.q_tokens.get() || !w.wk.get() || !w.wv.get() || !w.wo.get() ||
        !w.rms_ffn_weight.get() || !w.w1.get() || !w.w2.get() || !w.w3.get() || !w.rms_final_weight.get())
    {
        std::cerr << "Malloc for weights failed.\n";
        std::exit(EXIT_FAILURE);
    }
}

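// Scratch buffers used by forward(), sized once from the Config:
//   x           activation at the current timestep (dim)
//   xb, xb2     residual-branch scratch (dim)
//   hb, hb2     FFN hidden-layer scratch (hidden_dim)
//   q, k, v     query/key/value for the current position (dim / kv_dim)
//   att         attention scores, one row per head (n_heads x seq_len)
//   logits      output distribution (vocab_size)
//   key_cache, value_cache   (n_layers, seq_len, kv_dim)
//   xq, hq      int8-quantized views of xb / hb (quantized specialization only)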
template<>
void Transformer<float>::malloc_run_state()
{
    int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;
    s.x = std::make_unique<float[]>(config.dim);
    s.xb = std::make_unique<float[]>(config.dim);
    s.xb2 = std::make_unique<float[]>(config.dim);
    s.hb = std::make_unique<float[]>(config.hidden_dim);
    s.hb2 = std::make_unique<float[]>(config.hidden_dim);
    s.q = std::make_unique<float[]>(config.dim);
    s.k = std::make_unique<float[]>(kv_dim);
    s.v = std::make_unique<float[]>(kv_dim);
    s.att = std::make_unique<float[]>(config.seq_len * config.n_heads);
    s.logits = std::make_unique<float[]>(config.vocab_size);
    s.key_cache = std::make_unique<float[]>(config.n_layers * config.seq_len * kv_dim);
    s.value_cache = std::make_unique<float[]>(config.n_layers * config.seq_len * kv_dim);
    if (!s.x.get() || !s.xb.get() || !s.xb2.get() || !s.hb.get() || !s.hb2.get() || !s.q.get() || !s.k.get() || !s.v.get() || !s.att.get() || !s.logits.get() ||
        !s.key_cache.get() || !s.value_cache.get())
    {
        std::cerr << "Malloc for run state failed.\n";
        std::exit(EXIT_FAILURE);
    }
}

template<>
void Transformer<QuantizedTensor>::malloc_run_state()
{
    int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;
    s.x = std::make_unique<float[]>(config.dim);
    s.xb = std::make_unique<float[]>(config.dim);
    s.xb2 = std::make_unique<float[]>(config.dim);
    s.hb = std::make_unique<float[]>(config.hidden_dim);
    s.hb2 = std::make_unique<float[]>(config.hidden_dim);
    s.q = std::make_unique<float[]>(config.dim);
    s.k = std::make_unique<float[]>(kv_dim);
    s.v = std::make_unique<float[]>(kv_dim);
    s.att = std::make_unique<float[]>(config.seq_len * config.n_heads);
    s.logits = std::make_unique<float[]>(config.vocab_size);
    s.key_cache = std::make_unique<float[]>(config.n_layers * config.seq_len * kv_dim);
    s.value_cache = std::make_unique<float[]>(config.n_layers * config.seq_len * kv_dim);
    if (!s.x.get() || !s.xb.get() || !s.xb2.get() || !s.hb.get() || !s.hb2.get() || !s.q.get() || !s.k.get() || !s.v.get() || !s.att.get() || !s.logits.get() ||
        !s.key_cache.get() || !s.value_cache.get())
    {
        std::cerr << "Malloc for run state failed.\n";
        std::exit(EXIT_FAILURE);
    }

    s.xq = std::make_unique<QuantizedTensor[]>(1);
    s.hq = std::make_unique<QuantizedTensor[]>(1);
    if (!s.xq.get() || !s.hq.get())
    {
        std::cerr << "Malloc for run state xq or hq failed.\n";
        std::exit(EXIT_FAILURE);
    }
    s.xq[0].q = std::make_unique<int8_t[]>(config.dim);
    s.xq[0].s = std::make_unique<float[]>(config.dim);
    if (!s.xq[0].q.get() || !s.xq[0].s.get())
    {
        std::cerr << "Malloc for run state xq[0] failed.\n";
        std::exit(EXIT_FAILURE);
    }
    s.hq[0].q = std::make_unique<int8_t[]>(config.hidden_dim);
    s.hq[0].s = std::make_unique<float[]>(config.hidden_dim);
    if (!s.hq[0].q.get() || !s.hq[0].s.get())
    {
        std::cerr << "Malloc for run state hq[0] failed.\n";
        std::exit(EXIT_FAILURE);
    }
}

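// Legacy llama2.c checkpoint format: a raw Config struct followed by the fp32
// weight tensors in the fixed order read below. A negative vocab_size in the
// header signals that the classifier (wcls) is stored separately instead of
// sharing the token embedding table.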
template<>
void Transformer<float>::load_model(const std::string &checkpoint_path)
{
    std::ifstream file(checkpoint_path, std::ios::binary);
    if (!file)
    {
        std::cerr << "Couldn't open file " << checkpoint_path << '\n';
        std::exit(EXIT_FAILURE);
    }
    // read in the Config header
    file.read(reinterpret_cast<char *>(&config), sizeof(Config));
    shared_weights = config.vocab_size > 0 ? 1 : 0;
    config.vocab_size = std::abs(config.vocab_size);
    malloc_weights();
    int head_size = config.dim / config.n_heads;
    unsigned long long n_layers = config.n_layers;

    file.read(reinterpret_cast<char *>(w.token_embedding_table.get()), config.vocab_size * config.dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.rms_att_weight.get()), config.n_layers * config.dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.wq.get()), n_layers * config.dim * config.n_heads * head_size * sizeof(float));
    file.read(reinterpret_cast<char *>(w.wk.get()), n_layers * config.dim * config.n_kv_heads * head_size * sizeof(float));
    file.read(reinterpret_cast<char *>(w.wv.get()), n_layers * config.dim * config.n_kv_heads * head_size * sizeof(float));
    file.read(reinterpret_cast<char *>(w.wo.get()), n_layers * config.dim * config.n_heads * head_size * sizeof(float));
    file.read(reinterpret_cast<char *>(w.rms_ffn_weight.get()), n_layers * config.dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.w1.get()), n_layers * config.dim * config.hidden_dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.w2.get()), n_layers * config.dim * config.hidden_dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.w3.get()), n_layers * config.dim * config.hidden_dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.rms_final_weight.get()), config.dim * sizeof(float));

    if (!shared_weights)
    {
        // skip the (unused) RoPE frequency tables that precede wcls in this format
        file.seekg((config.seq_len * head_size) * sizeof(float), std::ios::cur);
        file.read(reinterpret_cast<char *>(w.wcls.get()), config.vocab_size * config.dim * sizeof(float));
    }
    file.close();
    malloc_run_state();
}

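// Version-2 (quantized) checkpoint format: magic 0x616b3432 ("ak42"), an int
// version (must be 2), the Config struct, a shared-classifier flag byte, and
// the quantization group size, all padded out to a 256-byte header. The fp32
// rmsnorm weights follow, then each weight matrix as one quantized tensor.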
template<>
void Transformer<QuantizedTensor>::load_model(const std::string &checkpoint_path)
{
    std::ifstream file(checkpoint_path, std::ios::binary);
    if (!file)
    {
        std::cerr << "Couldn't open file " << checkpoint_path << '\n';
        std::exit(EXIT_FAILURE);
    }

    uint32_t magic_number;
    file.read(reinterpret_cast<char *>(&magic_number), sizeof(uint32_t));
    if (magic_number != 0x616b3432)
    {
        std::cerr << "Bad magic number\n";
        std::exit(EXIT_FAILURE);
    }

    int version;
    file.read(reinterpret_cast<char *>(&version), sizeof(int));
    if (version != 2)
    {
        std::cerr << "Bad version " << version << ", need version 2\n";
        std::exit(EXIT_FAILURE);
    }

    file.read(reinterpret_cast<char *>(&config), sizeof(Config));

    // read in flags
    uint8_t shared_classifier; // a byte to indicate if the classifier is shared
    file.read(reinterpret_cast<char *>(&shared_classifier), sizeof(uint8_t));

    int group_size;
    file.read(reinterpret_cast<char *>(&group_size), sizeof(int));
    GS = group_size;

    shared_weights = shared_classifier;

    malloc_weights();
    int head_size = config.dim / config.n_heads;
    unsigned long long n_layers = config.n_layers;

    // weight data begins right after the fixed 256-byte header
    int header_size = 256;
    file.seekg(header_size, std::ios::beg);
    file.read(reinterpret_cast<char *>(w.rms_att_weight.get()), config.n_layers * config.dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.rms_ffn_weight.get()), n_layers * config.dim * sizeof(float));
    file.read(reinterpret_cast<char *>(w.rms_final_weight.get()), config.dim * sizeof(float));
    init_quantized_tensors(file, w.q_tokens.get(), 1, config.vocab_size * config.dim);
    // keep an fp32 copy of the token embeddings for the lookup in forward()
    dequantize(w.q_tokens.get(), w.token_embedding_table.get(), config.vocab_size * config.dim);

    init_quantized_tensors(file, w.wq.get(), n_layers, config.dim * config.n_heads * head_size);
    init_quantized_tensors(file, w.wk.get(), n_layers, config.dim * config.n_kv_heads * head_size);
    init_quantized_tensors(file, w.wv.get(), n_layers, config.dim * config.n_kv_heads * head_size);
    init_quantized_tensors(file, w.wo.get(), n_layers, config.dim * config.n_heads * head_size);

    init_quantized_tensors(file, w.w1.get(), n_layers, config.dim * config.hidden_dim);
    init_quantized_tensors(file, w.w2.get(), n_layers, config.dim * config.hidden_dim);
    init_quantized_tensors(file, w.w3.get(), n_layers, config.dim * config.hidden_dim);

    if (!shared_weights)
    {
        init_quantized_tensors(file, w.wcls.get(), 1, config.dim * config.vocab_size);
    }
    file.close();
    malloc_run_state();
}

template<>
float *Transformer<float>::forward(int token, int pos)
{
    int dim = config.dim;
    int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;
    int kv_mul = config.n_heads / config.n_kv_heads; // integer multiplier of the kv sharing in multiquery
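    // e.g. n_heads = 32 with n_kv_heads = 8 gives kv_mul = 4: each key/value
    // head serves 4 consecutive query heads (grouped-query attention)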
    int hidden_dim = config.hidden_dim;
    int head_size = dim / config.n_heads;
    // copy the token embedding into x
    std::memcpy(s.x.get(), w.token_embedding_table.get() + token * dim, dim * sizeof(*(s.x.get())));
    // forward all the layers
    for (decltype(config.n_layers) l = 0; l < config.n_layers; l++)
    {
        // attention rmsnorm
        rmsnorm(s.xb.get(), s.x.get(), w.rms_att_weight.get() + l * dim, dim);

        // qkv matmuls for this position
        matmul(s.q.get(), s.xb.get(), w.wq.get() + l * dim * dim, dim, dim);
        matmul(s.k.get(), s.xb.get(), w.wk.get() + l * dim * kv_dim, dim, kv_dim);
        matmul(s.v.get(), s.xb.get(), w.wv.get() + l * dim * kv_dim, dim, kv_dim);

        // RoPE relative positional encoding: complex-valued rotate q and k in each head
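        // each pair (x[i], x[i+1]) is treated as one complex number and rotated
        // by the angle a = pos * freq, where freq = 10000^(-head_dim/head_size):
        //   x[i]   <- x[i]*cos(a) - x[i+1]*sin(a)
        //   x[i+1] <- x[i]*sin(a) + x[i+1]*cos(a)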
        for (int i = 0; i < dim; i += 2)
        {
            int head_dim = i % head_size;
            float freq = 1.0f / powf(10000.0f, head_dim / (float) head_size);
            float val = pos * freq;
            float fcr = cosf(val);
            float fci = sinf(val);
            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
            for (int v = 0; v < rotn; v++)
            {
                float *vec = v == 0 ? s.q.get() : s.k.get(); // the vector to rotate (query or key)
                float v0 = vec[i];
                float v1 = vec[i + 1];
                vec[i] = v0 * fcr - v1 * fci;
                vec[i + 1] = v0 * fci + v1 * fcr;
            }
        }

        // save key,value at this time step (pos) to our kv cache
        int loff = l * config.seq_len * kv_dim; // kv cache layer offset for convenience
        float *key_cache_row = s.key_cache.get() + loff + pos * kv_dim;
        float *value_cache_row = s.value_cache.get() + loff + pos * kv_dim;
        std::memcpy(key_cache_row, s.k.get(), kv_dim * sizeof(*key_cache_row));
        std::memcpy(value_cache_row, s.v.get(), kv_dim * sizeof(*value_cache_row));

        // multihead attention. iterate over all heads
        for (int h = 0; h < config.n_heads; h++)
        {
            // get the query vector for this head
            float *q = s.q.get() + h * head_size;
            // attention scores for this head
            float *att = s.att.get() + h * config.seq_len;
            // iterate over all timesteps, including the current one
            for (int t = 0; t <= pos; t++)
            {
                // get the key vector for this head and at this timestep
                // key_cache layout: (n_layers, seq_len, n_kv_heads * head_size)
                float *k = s.key_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size;
                // calculate the attention score as the dot product of q and k
                float score = 0.0f;
                for (int i = 0; i < head_size; i++)
                {
                    score += q[i] * k[i];
                }
                score /= sqrtf(head_size);
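                // (the standard 1/sqrt(d_k) scaling of scaled dot-product attention)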
                // save the score to the attention buffer
                // att.shape = (n_heads, seq_len)
                att[t] = score;
            }

            // softmax the scores to get attention weights, from 0..pos inclusively
            softmax(att, pos + 1);

            // weighted sum of the values, store back into xb
            float *xb = s.xb.get() + h * head_size;
            std::memset(xb, 0, head_size * sizeof(float));
            for (int t = 0; t <= pos; t++)
            {
                // get the value vector for this head and at this timestep
                float *v = s.value_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size;
                // get the attention weight for this timestep
                float a = att[t];
                // accumulate the weighted value into xb
                for (int i = 0; i < head_size; i++)
                {
                    xb[i] += a * v[i];
                }
            }
        }

        // final matmul to get the output of the attention
        matmul(s.xb2.get(), s.xb.get(), w.wo.get() + l * dim * dim, dim, dim);

        // residual connection back into x
        for (int i = 0; i < dim; i++)
        {
            s.x[i] += s.xb2[i];
        }

        // ffn rmsnorm
        rmsnorm(s.xb.get(), s.x.get(), w.rms_ffn_weight.get() + l * dim, dim);

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        matmul(s.hb.get(), s.xb.get(), w.w1.get() + l * dim * hidden_dim, dim, hidden_dim);
        matmul(s.hb2.get(), s.xb.get(), w.w3.get() + l * dim * hidden_dim, dim, hidden_dim);

        // SwiGLU non-linearity
        for (int i = 0; i < hidden_dim; i++)
        {
            float val = s.hb[i];
            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
            val *= (1.0f / (1.0f + expf(-val)));
            // elementwise multiply with w3(x)
            val *= s.hb2[i];
            s.hb[i] = val;
        }

        // final matmul to get the output of the ffn
        matmul(s.xb.get(), s.hb.get(), w.w2.get() + l * dim * hidden_dim, hidden_dim, dim);

        // residual connection
        for (int i = 0; i < dim; i++)
        {
            s.x[i] += s.xb[i];
        }
    }
    // final rmsnorm
    rmsnorm(s.x.get(), s.x.get(), w.rms_final_weight.get(), dim);
    // classifier into logits

    if (shared_weights)
    {
        matmul(s.logits.get(), s.x.get(), w.token_embedding_table.get(), config.dim, config.vocab_size);
    }
    else
    {
        matmul(s.logits.get(), s.x.get(), w.wcls.get(), config.dim, config.vocab_size);
    }
    return s.logits.get();
}

template<>
float *Transformer<QuantizedTensor>::forward(int token, int pos)
{
    int dim = config.dim;
    int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;
    int kv_mul = config.n_heads / config.n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim = config.hidden_dim;
    int head_size = dim / config.n_heads;
    // copy the token embedding into x
    std::memcpy(s.x.get(), w.token_embedding_table.get() + token * dim, dim * sizeof(float));
    // forward all the layers
    for (decltype(config.n_layers) l = 0; l < config.n_layers; l++)
    {
        // attention rmsnorm
        rmsnorm(s.xb.get(), s.x.get(), w.rms_att_weight.get() + l * dim, dim);

        // qkv matmuls for this position
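        // activations are re-quantized to int8 before every weight matmul;
        // q_matmul (see run.h) presumably accumulates int8 products in groups
        // of GS values, scaling each group's partial sum by xq.s * w.s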
        quantize(s.xq.get(), s.xb.get(), dim);
        q_matmul(s.q.get(), s.xq.get(), w.wq.get() + l, dim, dim);
        q_matmul(s.k.get(), s.xq.get(), w.wk.get() + l, dim, kv_dim);
        q_matmul(s.v.get(), s.xq.get(), w.wv.get() + l, dim, kv_dim);

        // RoPE relative positional encoding: complex-valued rotate q and k in each head
        for (int i = 0; i < dim; i += 2)
        {
            int head_dim = i % head_size;
            float freq = 1.0f / powf(10000.0f, head_dim / (float) head_size);
            float val = pos * freq;
            float fcr = cosf(val);
            float fci = sinf(val);
            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
            for (int v = 0; v < rotn; v++)
            {
                float *vec = v == 0 ? s.q.get() : s.k.get(); // the vector to rotate (query or key)
                float v0 = vec[i];
                float v1 = vec[i + 1];
                vec[i] = v0 * fcr - v1 * fci;
                vec[i + 1] = v0 * fci + v1 * fcr;
            }
        }

        // save key,value at this time step (pos) to our kv cache
        int loff = l * config.seq_len * kv_dim; // kv cache layer offset for convenience
        float *key_cache_row = s.key_cache.get() + loff + pos * kv_dim;
        float *value_cache_row = s.value_cache.get() + loff + pos * kv_dim;
        std::memcpy(key_cache_row, s.k.get(), kv_dim * sizeof(*key_cache_row));
        std::memcpy(value_cache_row, s.v.get(), kv_dim * sizeof(*value_cache_row));

        // multihead attention. iterate over all heads
        for (int h = 0; h < config.n_heads; h++)
        {
            // get the query vector for this head
            float *q = s.q.get() + h * head_size;
            // attention scores for this head
            float *att = s.att.get() + h * config.seq_len;
            // iterate over all timesteps, including the current one
            for (int t = 0; t <= pos; t++)
            {
                // get the key vector for this head and at this timestep
                // key_cache layout: (n_layers, seq_len, n_kv_heads * head_size)
                float *k = s.key_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size;
                // calculate the attention score as the dot product of q and k
                float score = 0.0f;
                for (int i = 0; i < head_size; i++)
                {
                    score += q[i] * k[i];
                }
                score /= sqrtf(head_size);
                // save the score to the attention buffer
                // att.shape = (n_heads, seq_len)
                att[t] = score;
            }

            // softmax the scores to get attention weights, from 0..pos inclusively
            softmax(att, pos + 1);

            // weighted sum of the values, store back into xb
            float *xb = s.xb.get() + h * head_size;
            std::memset(xb, 0, head_size * sizeof(float));
            for (int t = 0; t <= pos; t++)
            {
                // get the value vector for this head and at this timestep
                float *v = s.value_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size;
                // get the attention weight for this timestep
                float a = att[t];
                // accumulate the weighted value into xb
                for (int i = 0; i < head_size; i++)
                {
                    xb[i] += a * v[i];
                }
            }
        }

        quantize(s.xq.get(), s.xb.get(), dim);
        // final matmul to get the output of the attention
        q_matmul(s.xb2.get(), s.xq.get(), w.wo.get() + l, dim, dim);

        // residual connection back into x
        for (int i = 0; i < dim; i++)
        {
            s.x[i] += s.xb2[i];
        }

        // ffn rmsnorm
        rmsnorm(s.xb.get(), s.x.get(), w.rms_ffn_weight.get() + l * dim, dim);

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        quantize(s.xq.get(), s.xb.get(), dim);
        q_matmul(s.hb.get(), s.xq.get(), w.w1.get() + l, dim, hidden_dim);
        q_matmul(s.hb2.get(), s.xq.get(), w.w3.get() + l, dim, hidden_dim);

        // SwiGLU non-linearity
        for (int i = 0; i < hidden_dim; i++)
        {
            float val = s.hb[i];
            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
            val *= (1.0f / (1.0f + expf(-val)));
            // elementwise multiply with w3(x)
            val *= s.hb2[i];
            s.hb[i] = val;
        }

        // final matmul to get the output of the ffn
        quantize(s.hq.get(), s.hb.get(), hidden_dim);
        q_matmul(s.xb.get(), s.hq.get(), w.w2.get() + l, hidden_dim, dim);

        // residual connection
        for (int i = 0; i < dim; i++)
        {
            s.x[i] += s.xb[i];
        }
    }
    // final rmsnorm
    rmsnorm(s.x.get(), s.x.get(), w.rms_final_weight.get(), dim);
    // classifier into logits

    quantize(s.xq.get(), s.x.get(), dim);
    if (shared_weights)
    {
        q_matmul(s.logits.get(), s.xq.get(), w.q_tokens.get(), config.dim, config.vocab_size);
    }
    else
    {
        q_matmul(s.logits.get(), s.xq.get(), w.wcls.get(), config.dim, config.vocab_size);
    }
    return s.logits.get();
}

void Tokenizer::build_tokenizer(const std::string &tokenizer_path, int size_for_vocab)
{
    vocab_size = size_for_vocab;
    vocab.resize(vocab_size);
    vocab_scores.resize(vocab_size);
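    // precompute the 256 single-byte pieces used when decoding raw-byte tokens:
    // byte i lives at byte_pieces[i * 2], followed by a NUL terminator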
    for (int i = 0; i < 256; i++)
    {
        byte_pieces[i * 2] = static_cast<unsigned char>(i);
        byte_pieces[i * 2 + 1] = '\0';
    }
    std::ifstream file(tokenizer_path, std::ios::binary);
    if (!file)
    {
        std::cerr << "Couldn't open file " << tokenizer_path << '\n';
        std::exit(EXIT_FAILURE);
    }
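    // tokenizer.bin layout: max_token_length (int), then vocab_size records of
    // { score (float), len (int), len bytes of token text }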
    file.read(reinterpret_cast<char *>(&max_token_length), sizeof(int));
    int len = 0;
    for (int i = 0; i < vocab_size; i++)
    {
        file.read(reinterpret_cast<char *>(&vocab_scores[i]), sizeof(float));
        file.read(reinterpret_cast<char *>(&len), sizeof(int));
        vocab[i] = std::make_unique<char[]>(len + 1);
        file.read(vocab[i].get(), len);
        vocab[i][len] = '\0';
    }
    file.close();
}

void Sampler::build_sampler(int vocab_size, float temperature, float topp, unsigned long long rng_seed)
{
    this->vocab_size = vocab_size;
    this->temperature = temperature;
    this->topp = topp;
    rng_state = rng_seed;
    // buffer used only with nucleus sampling; small enough to allocate unconditionally
    probindex = std::make_unique<ProbIndex[]>(vocab_size);
}

int Sampler::sample_argmax(float *probabilities, int n)
{
    // return the index that has the highest probability
    int max_i = 0;
    float max_p = probabilities[0];
    for (int i = 1; i < n; i++)
    {
        if (probabilities[i] > max_p)
        {
            max_i = i;
            max_p = probabilities[i];
        }
    }
    return max_i;
}

int Sampler::sample_mult(float *probabilities, int n, float coin)
{
    // sample index from probabilities (they must sum to 1!)
    // coin is a random number in [0, 1), usually from random_f32()
    float cdf = 0.0f;
    for (int i = 0; i < n; i++)
    {
        cdf += probabilities[i];
        if (coin < cdf)
        {
            return i;
        }
    }
    return n - 1; // in case of rounding errors
}

int Sampler::sample_topp(float *probabilities, int n, float topp, std::unique_ptr<ProbIndex[]> &probindex, float coin)
{
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens that exceed probability topp. This way we never sample tokens that
    // have very low probabilities and are less likely to go "off the rails".
    // coin is a random number in [0, 1), usually from random_f32()

    int n0 = 0;
    // sort indices in descending order of probabilities
    // values smaller than (1 - topp) / (n - 1) cannot be part of the result
    // so for efficiency we crop these out as candidates before sorting
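    // (why the cutoff is safe: if a token with prob < cutoff were in the nucleus,
    // the tokens ranked at or below it -- at most n - 1 of them, each no larger --
    // would sum to less than 1 - topp, contradicting that the tokens ranked above
    // it sum to at most topp)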
    const float cutoff = (1.0f - topp) / (n - 1);
    for (int i = 0; i < n; i++)
    {
        if (probabilities[i] >= cutoff)
        {
            probindex[n0].index = i;
            probindex[n0].prob = probabilities[i];
            n0++;
        }
    }
    std::sort(probindex.get(), probindex.get() + n0, compare_probindex);

    // truncate the list where cumulative probability exceeds topp
    float cumulative_prob = 0.0f;
    int last_idx = n0 - 1; // in case of rounding errors consider all elements
    for (int i = 0; i < n0; i++)
    {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp)
        {
            last_idx = i;
            break; // we've exceeded topp by including last_idx
        }
    }

    // sample from the truncated list
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++)
    {
        cdf += probindex[i].prob;
        if (r < cdf)
        {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index; // in case of rounding errors
}

int Sampler::sample(float *logits)
{
    // sample the token given the logits and some hyperparameters
    int next;
    if (temperature == 0.0f)
    {
        // greedy argmax sampling: take the token with the highest probability
        next = sample_argmax(logits, vocab_size);
    }
    else
    {
        // apply the temperature to the logits
        for (int q = 0; q < vocab_size; q++)
        {
            logits[q] /= temperature;
        }
        // apply softmax to the logits to get the probabilities for next token
        softmax(logits, vocab_size);
        // flip a (float) coin (this is our source of entropy for sampling)
        float coin = random_f32(&rng_state);
        // we sample from this distribution to get the next token
        if (topp <= 0 || topp >= 1)
        {
            // simply sample from the predicted probability distribution
            next = sample_mult(logits, vocab_size, coin);
        }
        else
        {
            // top-p (nucleus) sampling, clamping the least likely tokens to zero
            next = sample_topp(logits, vocab_size, topp, probindex, coin);
        }
    }
    return next;
}

void Tokenizer::encode(const std::string &text, const int8_t &bos, const int8_t &eos, std::unique_ptr<int[]> &tokens, int &n_tokens)
{
    // encode the string text (input) into an upper-bound preallocated tokens[] array
    // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
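    // overview: (1) split the input into UTF-8 codepoints and look each up in the
    // vocab, falling back to raw byte tokens, then (2) repeatedly merge the
    // adjacent pair with the highest vocab score, BPE style, until no mergeable
    // pair remains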
    if (text.empty())
    {
        std::cerr << "cannot encode NULL text" << '\n';
        std::exit(EXIT_FAILURE);
    }

    if (!sorted_vocab)
    {
        // lazily malloc and sort the vocabulary
        sorted_vocab = std::make_unique<TokenIndex[]>(vocab_size);
        for (int i = 0; i < vocab_size; i++)
        {
            sorted_vocab[i].str = std::string(vocab[i].get());
            sorted_vocab[i].id = i;
        }
        std::sort(sorted_vocab.get(), sorted_vocab.get() + vocab_size, compare_tokens);
    }

    // create a temporary buffer that will store merge candidates of always two consecutive tokens
    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
    std::string str_buffer;
    str_buffer.resize(max_token_length * 2 + 1 + 2);
    size_t str_len = 0;

    // start at 0 tokens
    n_tokens = 0;

    // add optional BOS (=1) token, if desired
    if (bos)
        tokens[n_tokens++] = 1;

    // add_dummy_prefix is true by default
    // so prepend a dummy prefix token to the input string, but only if text != ""
    // TODO: pretty sure this isn't correct in the general case but I don't have the
    // energy to read more of the sentencepiece code to figure out what it's doing
    if (text[0] != '\0')
    {
        int dummy_prefix = str_lookup(" ", sorted_vocab, vocab_size);
        tokens[n_tokens++] = dummy_prefix;
    }

    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
    // Code point ↔ UTF-8 conversion
    // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
    // U+0000            U+007F           0xxxxxxx
    // U+0080            U+07FF           110xxxxx  10xxxxxx
    // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
    // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

    // process the raw (UTF-8) byte sequence of the input string
    for (const char *c = text.c_str(); *c != '\0'; c++)
    {
        // reset buffer if the current byte is ASCII or a leading byte
        // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
        // 0x80 is 10000000
        // in UTF-8, all continuation bytes start with "10" in first two bits
        // so in English this is: "if this byte is not a continuation byte"
        if ((*c & 0xC0) != 0x80)
        {
            // this byte must be either a leading byte (11...) or an ASCII char (0x...)
            // => reset our location, as we're starting a new UTF-8 codepoint
            str_len = 0;
        }

        // append the current byte to the buffer
        str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
        str_buffer[str_len] = '\0';

        // while the next character is a continuation byte, continue appending
        // but if there are too many of them, just stop to avoid overrunning str_buffer size.
        if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4)
        {
            continue;
        }

        // ok c+1 is not a continuation byte, so we've read in a full codepoint
        int id = str_lookup(str_buffer, sorted_vocab, vocab_size);

        if (id != -1)
        {
            // we found this codepoint in vocab, add it as a token
            tokens[n_tokens++] = id;
        }
        else
        {
            // byte_fallback encoding: just encode each byte as a token
            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
            // so the individual bytes only start at index 3
            for (decltype(str_len) i = 0; i < str_len; i++)
            {
                tokens[n_tokens++] = (unsigned char) str_buffer[i] + 3;
            }
        }
        str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
    }

    // merge the best consecutive pair each iteration, according to the scores in vocab_scores
    while (1)
    {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;

        for (int i = 0; i < (n_tokens - 1); i++)
        {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            str_buffer = std::string(vocab[tokens[i]].get()) + vocab[tokens[i + 1]].get();
            int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
            if (id != -1 && vocab_scores[id] > best_score)
            {
                // this merge pair exists in vocab! record its score and position
                best_score = vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

        if (best_idx == -1)
        {
            break; // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id;
        // delete token at position best_idx+1, shift the entire sequence back 1
        for (int i = best_idx + 1; i < (n_tokens - 1); i++)
        {
            tokens[i] = tokens[i + 1];
        }
        n_tokens--; // token length decreased
    }

    // add optional EOS (=2) token, if desired
    if (eos)
        tokens[n_tokens++] = 2;
}

std::string Tokenizer::decode(int prev_token, int token)
{
    char *piece = vocab[token].get();
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ')
    {
        piece++;
    }
    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    // parse this and convert and return the actual byte
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1)
    {
        piece = (char *) byte_pieces + byte_val * 2;
    }
    return std::string(piece);
}

template<typename T>
void generate(Transformer<T> &transformer, Tokenizer &tokenizer, Sampler &sampler, std::string &prompt, int steps)
{
    // encode() rejects an empty string, so substitute a single NUL byte,
    // which encodes to just the BOS token
    std::string empty_prompt(1, '\0');
    if (prompt.empty())
    {
        prompt = empty_prompt;
    }
    // encode the (string) prompt into tokens sequence
    int num_prompt_tokens = 0;
    std::unique_ptr<int[]> prompt_tokens = std::make_unique<int[]>(prompt.size() + 3); // +3 for '\0', ?BOS, ?EOS
    tokenizer.encode(prompt, 1, 0, prompt_tokens, num_prompt_tokens);
    if (num_prompt_tokens < 1)
    {
        std::cerr << "something is wrong, expected at least 1 prompt token" << '\n';
        std::exit(EXIT_FAILURE);
    }
    // start the main loop
    long start = 0;               // used to time our code, only initialized after first iteration
    int next;                     // will store the next token in the sequence
    int token = prompt_tokens[0]; // kick off with the first token in the prompt
    int pos = 0;                  // position in the sequence
    while (pos < steps)
    {
        // forward the transformer to get logits for the next token
        float *logits = transformer.forward(token, pos);

        // advance the state machine
        if (pos < num_prompt_tokens - 1)
        {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos + 1];
        }
        else
        {
            // otherwise sample the next token from the logits
            next = sampler.sample(logits);
        }
        pos++;

        // data-dependent terminating condition: the BOS (=1) token delimits sequences
        if (next == 1)
        {
            break;
        }

        // print the token as string, decode it with the Tokenizer object
        std::string piece = tokenizer.decode(token, next);
        safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
        fflush(stdout);
        token = next;

        // init the timer here because the first iteration can be slower
        if (start == 0)
        {
            start = time_in_ms();
        }
    }
    std::cout << "\n";

    // report achieved tok/s (pos-1 because the timer starts after first iteration)
    if (pos > 1)
    {
        long end = time_in_ms();
        std::cerr << "achieved tok/s: " << (pos - 1) / static_cast<double>(end - start) * 1000 << std::endl;
    }
}

template<typename T>
void chat(Transformer<T> &transformer, Tokenizer &tokenizer, Sampler &sampler, std::string &cli_user_prompt, std::string &cli_system_prompt, int steps)
{
    // buffers for the system prompt and user prompt read from stdin
    std::string system_prompt;
    std::string user_prompt;
    std::string rendered_prompt;
    int num_prompt_tokens = 0;
    std::unique_ptr<int[]> prompt_tokens = std::make_unique<int[]>(1152); // upper bound on tokens in a rendered prompt
    int user_idx;

    // start the main loop
    int8_t user_turn = 1; // user starts
    int next;             // will store the next token in the sequence
    int token;            // stores the current token to feed into the transformer
    int pos = 0;          // position in the sequence
    while (pos < steps)
    {
        // when it is the user's turn to contribute tokens to the dialog...
        if (user_turn)
        {
            // get the (optional) system prompt at position 0
            if (pos == 0)
            {
                // at position 0, the user can also contribute a system prompt
                if (cli_system_prompt.empty())
                {
                    // system prompt was not passed in, attempt to get it from stdin
                    // (cap the read at 512 bytes; sizeof(std::string) would be the object size, not a content length)
                    read_stdin("Enter system prompt (optional): ", system_prompt, 512);
                }
                else
                {
                    // system prompt was passed in, use it
                    system_prompt = cli_system_prompt;
                }
            }
            // get the user prompt
            if (pos == 0 && !cli_user_prompt.empty())
            {
                // user prompt for position 0 was passed in, use it
                user_prompt = cli_user_prompt;
            }
            else
            {
                // otherwise get user prompt from stdin (again capped at 512 bytes)
                read_stdin("User: ", user_prompt, 512);
            }
            // render user/system prompts into the Llama 2 Chat schema
            if (pos == 0 && !system_prompt.empty())
            {
                rendered_prompt = "[INST] <<SYS>>\n" + system_prompt + "\n<</SYS>>\n\n" + user_prompt + " [/INST]";
            }
            else
            {
                rendered_prompt = "[INST] " + user_prompt + " [/INST]";
            }
            // encode the rendered prompt into tokens
            tokenizer.encode(rendered_prompt, 1, 0, prompt_tokens, num_prompt_tokens);
            user_idx = 0; // reset the user index
            user_turn = 0;
            std::cout << "Assistant: ";
        }

        // determine the token to pass into the transformer next
        if (user_idx < num_prompt_tokens)
        {
            // if we are still processing the input prompt, force the next prompt token
            token = prompt_tokens[user_idx++];
        }
        else
        {
            // otherwise use the next token sampled from previous turn
            token = next;
        }
        // EOS (=2) token ends the Assistant turn
        if (token == 2)
        {
            user_turn = 1;
        }

        // forward the transformer to get logits for the next token
        float *logits = transformer.forward(token, pos);
        next = sampler.sample(logits);
        pos++;

        if (user_idx >= num_prompt_tokens && next != 2)
        {
            // the Assistant is responding, so print its output
            std::string piece = tokenizer.decode(token, next);
            safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
            fflush(stdout);
        }
        if (next == 2)
        {
            std::cout << "\n";
        }
    }
    std::cout << "\n";
}

int main(int argc, char *argv[])
{
    std::string checkpoint_path;
    std::string tokenizer_path = "/initrd/assets/tokenizer.bin";
    float temperature = 1.0f;        // 0.0 = greedy deterministic. 1.0 = original. don't set higher
    float topp = 0.9f;               // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
    int steps = 256;                 // number of steps to run for
    std::string prompt;
    unsigned long long rng_seed = 0; // seed rng with time by default
    std::string mode = "generate";
    std::string system_prompt;
    if (argc >= 2)
    {
        checkpoint_path = argv[1];
    }
    else
    {
        checkpoint_path = "/initrd/assets/stories15M.bin";
    }

    for (int i = 2; i < argc; i += 2)
    {
        // do some basic validation
        if (i + 1 >= argc)
        {
            error_usage();
        } // must have arg after flag
        if (argv[i][0] != '-')
        {
            error_usage();
        } // must start with dash
        if (strlen(argv[i]) != 2)
        {
            error_usage();
        } // must be -x (one dash, one letter)
        // read in the args
        if (argv[i][1] == 't')
        {
            temperature = atof(argv[i + 1]);
        }
        else if (argv[i][1] == 'p')
        {
            topp = atof(argv[i + 1]);
        }
        else if (argv[i][1] == 's')
        {
            rng_seed = atoi(argv[i + 1]);
        }
        else if (argv[i][1] == 'n')
        {
            steps = atoi(argv[i + 1]);
        }
        else if (argv[i][1] == 'i')
        {
            prompt = argv[i + 1];
        }
        else if (argv[i][1] == 'z')
        {
            tokenizer_path = argv[i + 1];
        }
        else if (argv[i][1] == 'm')
        {
            mode = argv[i + 1];
        }
        else if (argv[i][1] == 'y')
        {
            system_prompt = argv[i + 1];
        }
        else
        {
            error_usage();
        }
    }
    if (rng_seed <= 0)
        rng_seed = (unsigned int) time(NULL);
    if (temperature < 0.0)
        temperature = 0.0;
    if (topp < 0.0 || 1.0 < topp)
        topp = 0.9;
    if (steps < 0)
        steps = 0;
    if (is_quantized_model(checkpoint_path))
    {
        Transformer<QuantizedTensor> transformer;
        transformer.load_model(checkpoint_path);
        if (steps == 0 || steps > transformer.config.seq_len)
            steps = transformer.config.seq_len; // override to max_seq_len
        Tokenizer tokenizer;
        tokenizer.build_tokenizer(tokenizer_path, transformer.config.vocab_size);
        Sampler sampler;
        sampler.build_sampler(transformer.config.vocab_size, temperature, topp, rng_seed);

        // run!
        if (mode == "generate")
        {
            generate(transformer, tokenizer, sampler, prompt, steps);
        }
        else if (mode == "chat")
        {
            chat(transformer, tokenizer, sampler, prompt, system_prompt, steps);
        }
        else
        {
            std::cerr << "unknown mode: " << mode << '\n';
            error_usage();
        }
    }
    else
    {
        // identical flow, but with fp32 weights
        Transformer<float> transformer;
        transformer.load_model(checkpoint_path);
        if (steps == 0 || steps > transformer.config.seq_len)
            steps = transformer.config.seq_len; // override to max_seq_len
        Tokenizer tokenizer;
        tokenizer.build_tokenizer(tokenizer_path, transformer.config.vocab_size);
        Sampler sampler;
        sampler.build_sampler(transformer.config.vocab_size, temperature, topp, rng_seed);

        // run!
        if (mode == "generate")
        {
            generate(transformer, tokenizer, sampler, prompt, steps);
        }
        else if (mode == "chat")
        {
            chat(transformer, tokenizer, sampler, prompt, system_prompt, steps);
        }
        else
        {
            std::cerr << "unknown mode: " << mode << '\n';
            error_usage();
        }
    }

    return 0;
}