#include "run.h"

#include <cstddef>
#include <new>
// Global quantization group size. Starts at 0; Transformer<QuantizedTensor>::load_model
// overwrites it with the group size stored in the checkpoint header.
int GS{0};
| 3 | |
// Prints the command-line help text to stderr and terminates the process with
// EXIT_FAILURE; control never returns to the caller.
void error_usage()
{
    // Kept as a single raw string so the option table reads exactly as printed.
    static const char *const usage_text = R"(Usage: run <checkpoint> [options]
Example: run model.bin -n 256 -i "Once upon a time"
Options:
-t <float> temperature in [0,inf], default 1.0
-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9
-s <int> random seed, default time(NULL)
-n <int> number of steps to run for, default 256. 0 = max_seq_len
-i <string> input prompt
-z <string> optional path to custom tokenizer
-m <string> mode: generate|chat, default: generate
-y <string> (optional) system prompt in chat mode
)";
    std::cerr << usage_text;
    std::exit(EXIT_FAILURE);
}
| 20 | |
| 21 | template<> |
| 22 | void Transformer<float>::malloc_weights() |
| 23 | { |
| 24 | int head_size = config.dim / config.n_heads; |
| 25 | // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models |
| 26 | unsigned long long n_layers = config.n_layers; |
| 27 | w.token_embedding_table = std::make_unique<float[]>(num: config.vocab_size * config.dim); |
| 28 | w.rms_att_weight = std::make_unique<float[]>(num: n_layers * config.dim); |
| 29 | w.wq = std::make_unique<float[]>(num: n_layers * config.dim * config.n_heads * head_size); |
| 30 | w.wk = std::make_unique<float[]>(num: n_layers * config.dim * config.n_kv_heads * head_size); |
| 31 | w.wv = std::make_unique<float[]>(num: n_layers * config.dim * config.n_kv_heads * head_size); |
| 32 | w.wo = std::make_unique<float[]>(num: n_layers * config.dim * config.n_heads * head_size); |
| 33 | w.rms_ffn_weight = std::make_unique<float[]>(num: n_layers * config.dim); |
| 34 | w.w1 = std::make_unique<float[]>(num: n_layers * config.dim * config.hidden_dim); |
| 35 | w.w2 = std::make_unique<float[]>(num: n_layers * config.dim * config.hidden_dim); |
| 36 | w.w3 = std::make_unique<float[]>(num: n_layers * config.dim * config.hidden_dim); |
| 37 | w.rms_final_weight = std::make_unique<float[]>(num: config.dim); |
| 38 | if (!shared_weights) |
| 39 | { |
| 40 | w.wcls = std::make_unique<float[]>(num: config.vocab_size * config.dim); |
| 41 | if (!w.wcls.get()) |
| 42 | { |
| 43 | std::cerr << "Malloc for wcls weights failed.\n" ; |
| 44 | std::exit(EXIT_FAILURE); |
| 45 | } |
| 46 | } |
| 47 | if (!w.token_embedding_table.get() || !w.rms_att_weight.get() || !w.wq.get() || !w.wk.get() || !w.wv.get() || !w.wo.get() || !w.rms_ffn_weight.get() || !w.w1.get() || |
| 48 | !w.w2.get() || !w.w3.get() || !w.rms_final_weight.get()) |
| 49 | { |
| 50 | std::cerr << "Malloc for weights failed.\n" ; |
| 51 | std::exit(EXIT_FAILURE); |
| 52 | } |
| 53 | } |
| 54 | |
| 55 | template<> |
| 56 | void Transformer<QuantizedTensor>::malloc_weights() |
| 57 | { |
| 58 | // int head_size = config.dim / config.n_heads; |
| 59 | // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models |
| 60 | unsigned long long n_layers = config.n_layers; |
| 61 | w.token_embedding_table = std::make_unique<float[]>(num: config.vocab_size * config.dim); |
| 62 | w.rms_att_weight = std::make_unique<float[]>(num: n_layers * config.dim); |
| 63 | w.wq = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 64 | w.wk = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 65 | w.wv = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 66 | w.wo = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 67 | w.rms_ffn_weight = std::make_unique<float[]>(num: n_layers * config.dim); |
| 68 | w.w1 = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 69 | w.w2 = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 70 | w.w3 = std::make_unique<QuantizedTensor[]>(num: n_layers); |
| 71 | w.rms_final_weight = std::make_unique<float[]>(num: config.dim); |
| 72 | |
| 73 | w.q_tokens = std::make_unique<QuantizedTensor[]>(num: 1); |
| 74 | |
| 75 | if (!shared_weights) |
| 76 | { |
| 77 | w.wcls = std::make_unique<QuantizedTensor[]>(num: 1); |
| 78 | if (!w.wcls.get()) |
| 79 | { |
| 80 | std::cerr << "Malloc for wcls weights failed.\n" ; |
| 81 | std::exit(EXIT_FAILURE); |
| 82 | } |
| 83 | } |
| 84 | if (!w.token_embedding_table.get() || !w.rms_att_weight.get() || !w.wq.get() || !w.q_tokens.get() || !w.wk.get() || !w.wv.get() || !w.wo.get() || |
| 85 | !w.rms_ffn_weight.get() || !w.w1.get() || !w.w2.get() || !w.w3.get() || !w.rms_final_weight.get()) |
| 86 | { |
| 87 | std::cerr << "Malloc for weights failed.\n" ; |
| 88 | std::exit(EXIT_FAILURE); |
| 89 | } |
| 90 | } |
| 91 | |
| 92 | template<> |
| 93 | void Transformer<float>::malloc_run_state() |
| 94 | { |
| 95 | int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; |
| 96 | s.x = std::make_unique<float[]>(num: config.dim); |
| 97 | s.xb = std::make_unique<float[]>(num: config.dim); |
| 98 | s.xb2 = std::make_unique<float[]>(num: config.dim); |
| 99 | s.hb = std::make_unique<float[]>(num: config.hidden_dim); |
| 100 | s.hb2 = std::make_unique<float[]>(num: config.hidden_dim); |
| 101 | s.q = std::make_unique<float[]>(num: config.dim); |
| 102 | s.k = std::make_unique<float[]>(num: kv_dim); |
| 103 | s.v = std::make_unique<float[]>(num: kv_dim); |
| 104 | s.att = std::make_unique<float[]>(num: config.seq_len * config.n_heads); |
| 105 | s.logits = std::make_unique<float[]>(num: config.vocab_size); |
| 106 | s.key_cache = std::make_unique<float[]>(num: config.n_layers * config.seq_len * kv_dim); |
| 107 | s.value_cache = std::make_unique<float[]>(num: config.n_layers * config.seq_len * kv_dim); |
| 108 | if (!s.x.get() || !s.xb.get() || !s.xb2.get() || !s.hb.get() || !s.hb2.get() || !s.q.get() || !s.k.get() || !s.v.get() || !s.att.get() || !s.logits.get() || |
| 109 | !s.key_cache.get() || !s.value_cache.get()) |
| 110 | { |
| 111 | std::cerr << "Malloc for run state failed.\n" ; |
| 112 | std::exit(EXIT_FAILURE); |
| 113 | } |
| 114 | } |
| 115 | |
| 116 | template<> |
| 117 | void Transformer<QuantizedTensor>::malloc_run_state() |
| 118 | { |
| 119 | int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; |
| 120 | s.x = std::make_unique<float[]>(num: config.dim); |
| 121 | s.xb = std::make_unique<float[]>(num: config.dim); |
| 122 | s.xb2 = std::make_unique<float[]>(num: config.dim); |
| 123 | s.hb = std::make_unique<float[]>(num: config.hidden_dim); |
| 124 | s.hb2 = std::make_unique<float[]>(num: config.hidden_dim); |
| 125 | s.q = std::make_unique<float[]>(num: config.dim); |
| 126 | s.k = std::make_unique<float[]>(num: kv_dim); |
| 127 | s.v = std::make_unique<float[]>(num: kv_dim); |
| 128 | s.att = std::make_unique<float[]>(num: config.seq_len * config.n_heads); |
| 129 | s.logits = std::make_unique<float[]>(num: config.vocab_size); |
| 130 | s.key_cache = std::make_unique<float[]>(num: config.n_layers * config.seq_len * kv_dim); |
| 131 | s.value_cache = std::make_unique<float[]>(num: config.n_layers * config.seq_len * kv_dim); |
| 132 | if (!s.x.get() || !s.xb.get() || !s.xb2.get() || !s.hb.get() || !s.hb2.get() || !s.q.get() || !s.k.get() || !s.v.get() || !s.att.get() || !s.logits.get() || |
| 133 | !s.key_cache.get() || !s.value_cache.get()) |
| 134 | { |
| 135 | std::cerr << "Malloc for run state failed.\n" ; |
| 136 | std::exit(EXIT_FAILURE); |
| 137 | } |
| 138 | |
| 139 | s.xq = std::make_unique<QuantizedTensor[]>(num: 1); |
| 140 | s.hq = std::make_unique<QuantizedTensor[]>(num: 1); |
| 141 | if (!s.xq.get() || !s.hq.get()) |
| 142 | { |
| 143 | std::cerr << "Malloc for run state xq or hq failed.\n" ; |
| 144 | std::exit(EXIT_FAILURE); |
| 145 | } |
| 146 | s.xq[0].q = std::make_unique<int8_t[]>(num: config.dim); |
| 147 | s.xq[0].s = std::make_unique<float[]>(num: config.dim); |
| 148 | if (!s.xq[0].q.get() || !s.xq[0].s.get()) |
| 149 | { |
| 150 | std::cerr << "Malloc for run state xq[0] failed.\n" ; |
| 151 | std::exit(EXIT_FAILURE); |
| 152 | } |
| 153 | s.hq[0].q = std::make_unique<int8_t[]>(num: config.hidden_dim); |
| 154 | s.hq[0].s = std::make_unique<float[]>(num: config.hidden_dim); |
| 155 | if (!s.hq[0].q.get() || !s.hq[0].s.get()) |
| 156 | { |
| 157 | std::cerr << "Malloc for run state hq[0] failed.\n" ; |
| 158 | std::exit(EXIT_FAILURE); |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | template<> |
| 163 | void Transformer<float>::load_model(const std::string &checkpoint_path) |
| 164 | { |
| 165 | std::ifstream file(checkpoint_path, std::ios::binary); |
| 166 | if (!file) |
| 167 | { |
| 168 | std::cerr << "Couldn't open file " << checkpoint_path << '\n'; |
| 169 | std::exit(EXIT_FAILURE); |
| 170 | } |
| 171 | // 60816028 bytes |
| 172 | file.read(s: reinterpret_cast<char *>(&config), n: sizeof(Config)); |
| 173 | shared_weights = config.vocab_size > 0 ? 1 : 0; |
| 174 | config.vocab_size = std::abs(number: config.vocab_size); |
| 175 | malloc_weights(); |
| 176 | int head_size = config.dim / config.n_heads; |
| 177 | unsigned long long n_layers = config.n_layers; |
| 178 | |
| 179 | file.read(s: reinterpret_cast<char *>(w.token_embedding_table.get()), n: config.vocab_size * config.dim * sizeof(float)); |
| 180 | file.read(s: reinterpret_cast<char *>(w.rms_att_weight.get()), n: config.n_layers * config.dim * sizeof(float)); |
| 181 | file.read(s: reinterpret_cast<char *>(w.wq.get()), n: n_layers * config.dim * config.n_heads * head_size * sizeof(float)); |
| 182 | file.read(s: reinterpret_cast<char *>(w.wk.get()), n: n_layers * config.dim * config.n_kv_heads * head_size * sizeof(float)); |
| 183 | file.read(s: reinterpret_cast<char *>(w.wv.get()), n: n_layers * config.dim * config.n_kv_heads * head_size * sizeof(float)); |
| 184 | file.read(s: reinterpret_cast<char *>(w.wo.get()), n: n_layers * config.dim * config.n_heads * head_size * sizeof(float)); |
| 185 | file.read(s: reinterpret_cast<char *>(w.rms_ffn_weight.get()), n: n_layers * config.dim * sizeof(float)); |
| 186 | file.read(s: reinterpret_cast<char *>(w.w1.get()), n: n_layers * config.dim * config.hidden_dim * sizeof(float)); |
| 187 | file.read(s: reinterpret_cast<char *>(w.w2.get()), n: n_layers * config.dim * config.hidden_dim * sizeof(float)); |
| 188 | file.read(s: reinterpret_cast<char *>(w.w3.get()), n: n_layers * config.dim * config.hidden_dim * sizeof(float)); |
| 189 | file.read(s: reinterpret_cast<char *>(w.rms_final_weight.get()), n: config.dim * sizeof(float)); |
| 190 | |
| 191 | if (!shared_weights) |
| 192 | { |
| 193 | file.seekg((config.seq_len * head_size) * sizeof(float), std::ios::cur); |
| 194 | file.read(s: reinterpret_cast<char *>(w.wcls.get()), n: config.vocab_size * config.dim * sizeof(float)); |
| 195 | } |
| 196 | file.close(); |
| 197 | malloc_run_state(); |
| 198 | } |
| 199 | |
| 200 | template<> |
| 201 | void Transformer<QuantizedTensor>::load_model(const std::string &checkpoint_path) |
| 202 | { |
| 203 | |
| 204 | std::ifstream file(checkpoint_path, std::ios::binary); |
| 205 | if (!file) |
| 206 | { |
| 207 | std::cerr << "Couldn't open file " << checkpoint_path << '\n'; |
| 208 | std::exit(EXIT_FAILURE); |
| 209 | } |
| 210 | |
| 211 | uint32_t magic_number; |
| 212 | file.read(s: reinterpret_cast<char *>(&magic_number), n: sizeof(uint32_t)); |
| 213 | if (magic_number != 0x616b3432) |
| 214 | { |
| 215 | std::cerr << "Bad magic number\n" ; |
| 216 | std::exit(EXIT_FAILURE); |
| 217 | } |
| 218 | |
| 219 | int version; |
| 220 | file.read(s: reinterpret_cast<char *>(&version), n: sizeof(int)); |
| 221 | if (version != 2) |
| 222 | { |
| 223 | std::cerr << "Bad version " << version << ", need version 2\n" ; |
| 224 | std::exit(EXIT_FAILURE); |
| 225 | } |
| 226 | |
| 227 | file.read(s: reinterpret_cast<char *>(&config), n: sizeof(Config)); |
| 228 | |
| 229 | // read in flags |
| 230 | uint8_t shared_classifier; // a byte to indicate if the classifier is shared |
| 231 | file.read(s: reinterpret_cast<char *>(&shared_classifier), n: sizeof(uint8_t)); |
| 232 | |
| 233 | int group_size; |
| 234 | file.read(s: reinterpret_cast<char *>(&group_size), n: sizeof(int)); |
| 235 | GS = group_size; |
| 236 | |
| 237 | shared_weights = shared_classifier; |
| 238 | // config.vocab_size = std::abs(config.vocab_size); |
| 239 | |
| 240 | malloc_weights(); |
| 241 | int head_size = config.dim / config.n_heads; |
| 242 | unsigned long long n_layers = config.n_layers; |
| 243 | |
| 244 | int = 256; |
| 245 | file.seekg(header_size, std::ios::beg); |
| 246 | file.read(s: reinterpret_cast<char *>(w.rms_att_weight.get()), n: config.n_layers * config.dim * sizeof(float)); |
| 247 | file.read(s: reinterpret_cast<char *>(w.rms_ffn_weight.get()), n: n_layers * config.dim * sizeof(float)); |
| 248 | file.read(s: reinterpret_cast<char *>(w.rms_final_weight.get()), n: config.dim * sizeof(float)); |
| 249 | init_quantized_tensors(file, w: w.q_tokens.get(), n_layers: 1, each_layer: config.vocab_size * config.dim); |
| 250 | dequantize(qx: w.q_tokens.get(), x: w.token_embedding_table.get(), n: config.vocab_size * config.dim); |
| 251 | |
| 252 | init_quantized_tensors(file, w: w.wq.get(), n_layers, each_layer: config.dim * config.n_heads * head_size); |
| 253 | init_quantized_tensors(file, w: w.wk.get(), n_layers, each_layer: config.dim * config.n_kv_heads * head_size); |
| 254 | init_quantized_tensors(file, w: w.wv.get(), n_layers, each_layer: config.dim * config.n_kv_heads * head_size); |
| 255 | init_quantized_tensors(file, w: w.wo.get(), n_layers, each_layer: config.dim * config.n_heads * head_size); |
| 256 | |
| 257 | init_quantized_tensors(file, w: w.w1.get(), n_layers, each_layer: config.dim * config.hidden_dim); |
| 258 | init_quantized_tensors(file, w: w.w2.get(), n_layers, each_layer: config.dim * config.hidden_dim); |
| 259 | init_quantized_tensors(file, w: w.w3.get(), n_layers, each_layer: config.dim * config.hidden_dim); |
| 260 | |
| 261 | if (!shared_weights) |
| 262 | { |
| 263 | init_quantized_tensors(file, w: w.wcls.get(), n_layers: 1, each_layer: config.dim * config.vocab_size); |
| 264 | } |
| 265 | file.close(); |
| 266 | malloc_run_state(); |
| 267 | } |
| 268 | |
| 269 | template<> |
| 270 | float *Transformer<float>::forward(int token, int pos) |
| 271 | { |
| 272 | int dim = config.dim; |
| 273 | int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; |
| 274 | int kv_mul = config.n_heads / config.n_kv_heads; // integer multiplier of the kv sharing in multiquery |
| 275 | int hidden_dim = config.hidden_dim; |
| 276 | int head_size = dim / config.n_heads; |
| 277 | // copy the token embedding into x |
| 278 | std::memcpy(dest: s.x.get(), src: w.token_embedding_table.get() + token * dim, size: dim * sizeof(*(s.x.get()))); |
| 279 | // forward all the layers |
| 280 | for (decltype(config.n_layers) l = 0; l < config.n_layers; l++) |
| 281 | { |
| 282 | // attention rmsnorm |
| 283 | rmsnorm(o: s.xb.get(), x: s.x.get(), weight: w.rms_att_weight.get() + l * dim, size: dim); |
| 284 | |
| 285 | // qkv matmuls for this position |
| 286 | matmul(xout: s.q.get(), x: s.xb.get(), w: w.wq.get() + l * dim * dim, n: dim, d: dim); |
| 287 | matmul(xout: s.k.get(), x: s.xb.get(), w: w.wk.get() + l * dim * kv_dim, n: dim, d: kv_dim); |
| 288 | matmul(xout: s.v.get(), x: s.xb.get(), w: w.wv.get() + l * dim * kv_dim, n: dim, d: kv_dim); |
| 289 | |
| 290 | // RoPE relative positional encoding: complex-valued rotate q and k in each head |
| 291 | for (int i = 0; i < dim; i += 2) |
| 292 | { |
| 293 | int head_dim = i % head_size; |
| 294 | float freq = 1.0f / powf(x: 10000.0f, y: head_dim / (float) head_size); |
| 295 | float val = pos * freq; |
| 296 | float fcr = cosf(x: val); |
| 297 | float fci = sinf(x: val); |
| 298 | int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only |
| 299 | for (int v = 0; v < rotn; v++) |
| 300 | { |
| 301 | float *vec = v == 0 ? s.q.get() : s.k.get(); // the vector to rotate (query or key) |
| 302 | float v0 = vec[i]; |
| 303 | float v1 = vec[i + 1]; |
| 304 | vec[i] = v0 * fcr - v1 * fci; |
| 305 | vec[i + 1] = v0 * fci + v1 * fcr; |
| 306 | } |
| 307 | } |
| 308 | |
| 309 | // save key,value at this time step (pos) to our kv cache |
| 310 | int loff = l * config.seq_len * kv_dim; // kv cache layer offset for convenience |
| 311 | float *key_cache_row = s.key_cache.get() + loff + pos * kv_dim; |
| 312 | float *value_cache_row = s.value_cache.get() + loff + pos * kv_dim; |
| 313 | std::memcpy(dest: key_cache_row, src: s.k.get(), size: kv_dim * sizeof(*key_cache_row)); |
| 314 | std::memcpy(dest: value_cache_row, src: s.v.get(), size: kv_dim * sizeof(*value_cache_row)); |
| 315 | |
| 316 | // multihead attention. iterate over all heads |
| 317 | int h; |
| 318 | for (h = 0; h < config.n_heads; h++) |
| 319 | { |
| 320 | // get the query vector for this head |
| 321 | float *q = s.q.get() + h * head_size; |
| 322 | // attention scores for this head |
| 323 | float *att = s.att.get() + h * config.seq_len; |
| 324 | // iterate over all timesteps, including the current one |
| 325 | for (int t = 0; t <= pos; t++) |
| 326 | { |
| 327 | // get the key vector for this head and at this timestep |
| 328 | float *k = s.key_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size; |
| 329 | // calculate the attention score as the dot product of q and k |
| 330 | float score = 0.0f; |
| 331 | for (int i = 0; i < head_size; i++) |
| 332 | { |
| 333 | // q.shape = (1,head_size) k.shape= (head_size, n_head, seq_len) |
| 334 | score += q[i] * k[i]; |
| 335 | } |
| 336 | score /= sqrtf(x: head_size); |
| 337 | // save the score to the attention buffer |
| 338 | // att.shape = (n_heads, seq_len) |
| 339 | att[t] = score; |
| 340 | } |
| 341 | |
| 342 | // softmax the scores to get attention weights, from 0..pos inclusively |
| 343 | softmax(x: att, size: pos + 1); |
| 344 | |
| 345 | // weighted sum of the values, store back into xb |
| 346 | float *xb = s.xb.get() + h * head_size; |
| 347 | std::memset(dest: xb, c: 0, size: head_size * sizeof(float)); |
| 348 | for (int t = 0; t <= pos; t++) |
| 349 | { |
| 350 | // get the value vector for this head and at this timestep |
| 351 | float *v = s.value_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size; |
| 352 | // get the attention weight for this timestep |
| 353 | float a = att[t]; |
| 354 | // accumulate the weighted value into xb |
| 355 | for (int i = 0; i < head_size; i++) |
| 356 | { |
| 357 | xb[i] += a * v[i]; |
| 358 | } |
| 359 | } |
| 360 | } |
| 361 | |
| 362 | // final matmul to get the output of the attention |
| 363 | matmul(xout: s.xb2.get(), x: s.xb.get(), w: w.wo.get() + l * dim * dim, n: dim, d: dim); |
| 364 | |
| 365 | // residual connection back into x |
| 366 | for (int i = 0; i < dim; i++) |
| 367 | { |
| 368 | s.x[i] += s.xb2[i]; |
| 369 | } |
| 370 | |
| 371 | // ffn rmsnorm |
| 372 | rmsnorm(o: s.xb.get(), x: s.x.get(), weight: w.rms_ffn_weight.get() + l * dim, size: dim); |
| 373 | |
| 374 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 375 | // first calculate self.w1(x) and self.w3(x) |
| 376 | matmul(xout: s.hb.get(), x: s.xb.get(), w: w.w1.get() + l * dim * hidden_dim, n: dim, d: hidden_dim); |
| 377 | matmul(xout: s.hb2.get(), x: s.xb.get(), w: w.w3.get() + l * dim * hidden_dim, n: dim, d: hidden_dim); |
| 378 | |
| 379 | // SwiGLU non-linearity |
| 380 | for (int i = 0; i < hidden_dim; i++) |
| 381 | { |
| 382 | float val = s.hb[i]; |
| 383 | // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid |
| 384 | val *= (1.0f / (1.0f + expf(x: -val))); |
| 385 | // elementwise multiply with w3(x) |
| 386 | val *= s.hb2[i]; |
| 387 | s.hb[i] = val; |
| 388 | } |
| 389 | |
| 390 | // final matmul to get the output of the ffn |
| 391 | matmul(xout: s.xb.get(), x: s.hb.get(), w: w.w2.get() + l * dim * hidden_dim, n: hidden_dim, d: dim); |
| 392 | |
| 393 | // residual connection |
| 394 | for (int i = 0; i < dim; i++) |
| 395 | { |
| 396 | s.x[i] += s.xb[i]; |
| 397 | } |
| 398 | } |
| 399 | // final rmsnorm |
| 400 | rmsnorm(o: s.x.get(), x: s.x.get(), weight: w.rms_final_weight.get(), size: dim); |
| 401 | // classifier into logits |
| 402 | |
| 403 | if (shared_weights) |
| 404 | { |
| 405 | // w.wcls = std::move(w.token_embedding_table); |
| 406 | matmul(xout: s.logits.get(), x: s.x.get(), w: w.token_embedding_table.get(), n: config.dim, d: config.vocab_size); |
| 407 | } |
| 408 | else |
| 409 | { |
| 410 | matmul(xout: s.logits.get(), x: s.x.get(), w: w.wcls.get(), n: config.dim, d: config.vocab_size); |
| 411 | } |
| 412 | return s.logits.get(); |
| 413 | } |
| 414 | |
| 415 | template<> |
| 416 | float *Transformer<QuantizedTensor>::forward(int token, int pos) |
| 417 | { |
| 418 | int dim = config.dim; |
| 419 | int kv_dim = (config.dim * config.n_kv_heads) / config.n_heads; |
| 420 | int kv_mul = config.n_heads / config.n_kv_heads; // integer multiplier of the kv sharing in multiquery |
| 421 | int hidden_dim = config.hidden_dim; |
| 422 | int head_size = dim / config.n_heads; |
| 423 | // copy the token embedding into x |
| 424 | std::memcpy(dest: s.x.get(), src: w.token_embedding_table.get() + token * dim, size: dim * sizeof(float)); |
| 425 | // forward all the layers |
| 426 | for (decltype(config.n_layers) l = 0; l < config.n_layers; l++) |
| 427 | { |
| 428 | // attention rmsnorm |
| 429 | rmsnorm(o: s.xb.get(), x: s.x.get(), weight: w.rms_att_weight.get() + l * dim, size: dim); |
| 430 | |
| 431 | // qkv matmuls for this position |
| 432 | quantize(qx: s.xq.get(), x: s.xb.get(), n: dim); |
| 433 | q_matmul(xout: s.q.get(), x: s.xq.get(), w: w.wq.get() + l, n: dim, d: dim); |
| 434 | q_matmul(xout: s.k.get(), x: s.xq.get(), w: w.wk.get() + l, n: dim, d: kv_dim); |
| 435 | q_matmul(xout: s.v.get(), x: s.xq.get(), w: w.wv.get() + l, n: dim, d: kv_dim); |
| 436 | |
| 437 | // RoPE relative positional encoding: complex-valued rotate q and k in each head |
| 438 | for (int i = 0; i < dim; i += 2) |
| 439 | { |
| 440 | int head_dim = i % head_size; |
| 441 | float freq = 1.0f / powf(x: 10000.0f, y: head_dim / (float) head_size); |
| 442 | float val = pos * freq; |
| 443 | float fcr = cosf(x: val); |
| 444 | float fci = sinf(x: val); |
| 445 | int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only |
| 446 | for (int v = 0; v < rotn; v++) |
| 447 | { |
| 448 | float *vec = v == 0 ? s.q.get() : s.k.get(); // the vector to rotate (query or key) |
| 449 | float v0 = vec[i]; |
| 450 | float v1 = vec[i + 1]; |
| 451 | vec[i] = v0 * fcr - v1 * fci; |
| 452 | vec[i + 1] = v0 * fci + v1 * fcr; |
| 453 | } |
| 454 | } |
| 455 | |
| 456 | // save key,value at this time step (pos) to our kv cache |
| 457 | int loff = l * config.seq_len * kv_dim; // kv cache layer offset for convenience |
| 458 | float *key_cache_row = s.key_cache.get() + loff + pos * kv_dim; |
| 459 | float *value_cache_row = s.value_cache.get() + loff + pos * kv_dim; |
| 460 | std::memcpy(dest: key_cache_row, src: s.k.get(), size: kv_dim * sizeof(*key_cache_row)); |
| 461 | std::memcpy(dest: value_cache_row, src: s.v.get(), size: kv_dim * sizeof(*value_cache_row)); |
| 462 | |
| 463 | // multihead attention. iterate over all heads |
| 464 | int h; |
| 465 | for (h = 0; h < config.n_heads; h++) |
| 466 | { |
| 467 | // get the query vector for this head |
| 468 | float *q = s.q.get() + h * head_size; |
| 469 | // attention scores for this head |
| 470 | float *att = s.att.get() + h * config.seq_len; |
| 471 | // iterate over all timesteps, including the current one |
| 472 | for (int t = 0; t <= pos; t++) |
| 473 | { |
| 474 | // get the key vector for this head and at this timestep |
| 475 | float *k = s.key_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size; |
| 476 | // calculate the attention score as the dot product of q and k |
| 477 | float score = 0.0f; |
| 478 | for (int i = 0; i < head_size; i++) |
| 479 | { |
| 480 | // q.shape = (1,head_size) k.shape= (head_size, n_head, seq_len) |
| 481 | score += q[i] * k[i]; |
| 482 | } |
| 483 | score /= sqrtf(x: head_size); |
| 484 | // save the score to the attention buffer |
| 485 | // att.shape = (n_heads, seq_len) |
| 486 | att[t] = score; |
| 487 | } |
| 488 | |
| 489 | // softmax the scores to get attention weights, from 0..pos inclusively |
| 490 | softmax(x: att, size: pos + 1); |
| 491 | |
| 492 | // weighted sum of the values, store back into xb |
| 493 | float *xb = s.xb.get() + h * head_size; |
| 494 | std::memset(dest: xb, c: 0, size: head_size * sizeof(float)); |
| 495 | for (int t = 0; t <= pos; t++) |
| 496 | { |
| 497 | // get the value vector for this head and at this timestep |
| 498 | float *v = s.value_cache.get() + loff + t * kv_dim + (h / kv_mul) * head_size; |
| 499 | // get the attention weight for this timestep |
| 500 | float a = att[t]; |
| 501 | // accumulate the weighted value into xb |
| 502 | for (int i = 0; i < head_size; i++) |
| 503 | { |
| 504 | xb[i] += a * v[i]; |
| 505 | } |
| 506 | } |
| 507 | } |
| 508 | |
| 509 | quantize(qx: s.xq.get(), x: s.xb.get(), n: dim); |
| 510 | // final matmul to get the output of the attention |
| 511 | q_matmul(xout: s.xb2.get(), x: s.xq.get(), w: w.wo.get() + l, n: dim, d: dim); |
| 512 | |
| 513 | // residual connection back into x |
| 514 | for (int i = 0; i < dim; i++) |
| 515 | { |
| 516 | s.x[i] += s.xb2[i]; |
| 517 | } |
| 518 | |
| 519 | // ffn rmsnorm |
| 520 | rmsnorm(o: s.xb.get(), x: s.x.get(), weight: w.rms_ffn_weight.get() + l * dim, size: dim); |
| 521 | |
| 522 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 523 | // first calculate self.w1(x) and self.w3(x) |
| 524 | quantize(qx: s.xq.get(), x: s.xb.get(), n: dim); |
| 525 | q_matmul(xout: s.hb.get(), x: s.xq.get(), w: w.w1.get() + l, n: dim, d: hidden_dim); |
| 526 | q_matmul(xout: s.hb2.get(), x: s.xq.get(), w: w.w3.get() + l, n: dim, d: hidden_dim); |
| 527 | |
| 528 | // SwiGLU non-linearity |
| 529 | for (int i = 0; i < hidden_dim; i++) |
| 530 | { |
| 531 | float val = s.hb[i]; |
| 532 | // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid |
| 533 | val *= (1.0f / (1.0f + expf(x: -val))); |
| 534 | // elementwise multiply with w3(x) |
| 535 | val *= s.hb2[i]; |
| 536 | s.hb[i] = val; |
| 537 | } |
| 538 | |
| 539 | // final matmul to get the output of the ffn |
| 540 | quantize(qx: s.hq.get(), x: s.hb.get(), n: hidden_dim); |
| 541 | q_matmul(xout: s.xb.get(), x: s.hq.get(), w: w.w2.get() + l, n: hidden_dim, d: dim); |
| 542 | |
| 543 | // residual connection |
| 544 | for (int i = 0; i < dim; i++) |
| 545 | { |
| 546 | s.x[i] += s.xb[i]; |
| 547 | } |
| 548 | } |
| 549 | // final rmsnorm |
| 550 | rmsnorm(o: s.x.get(), x: s.x.get(), weight: w.rms_final_weight.get(), size: dim); |
| 551 | // classifier into logits |
| 552 | |
| 553 | quantize(qx: s.xq.get(), x: s.x.get(), n: dim); |
| 554 | if (shared_weights) |
| 555 | { |
| 556 | // w.wcls = std::move(w.token_embedding_table); |
| 557 | q_matmul(xout: s.logits.get(), x: s.xq.get(), w: w.q_tokens.get(), n: config.dim, d: config.vocab_size); |
| 558 | } |
| 559 | else |
| 560 | { |
| 561 | q_matmul(xout: s.logits.get(), x: s.xq.get(), w: w.wcls.get(), n: config.dim, d: config.vocab_size); |
| 562 | } |
| 563 | return s.logits.get(); |
| 564 | } |
| 565 | |
| 566 | void Tokenizer::build_tokenizer(const std::string &tokenizer_path, int size_for_vacab) |
| 567 | { |
| 568 | vocab_size = size_for_vacab; |
| 569 | vocab.resize(new_size: vocab_size); |
| 570 | vocab_scores.resize(new_size: vocab_size); |
| 571 | for (int i = 0; i < 256; i++) |
| 572 | { |
| 573 | byte_pieces[i * 2] = static_cast<unsigned char>(i); |
| 574 | byte_pieces[i * 2 + 1] = '\0'; |
| 575 | } |
| 576 | std::ifstream file(tokenizer_path, std::ios::binary); |
| 577 | if (!file) |
| 578 | { |
| 579 | std::cerr << "Couldn't open file " << tokenizer_path << '\n'; |
| 580 | std::exit(EXIT_FAILURE); |
| 581 | } |
| 582 | file.read(s: reinterpret_cast<char *>(&max_token_length), n: sizeof(int)); |
| 583 | int len = 0; |
| 584 | for (int i = 0; i < vocab_size; i++) |
| 585 | { |
| 586 | file.read(s: reinterpret_cast<char *>(&vocab_scores[i]), n: sizeof(float)); |
| 587 | file.read(s: reinterpret_cast<char *>(&len), n: sizeof(int)); |
| 588 | vocab[i] = std::make_unique<char[]>(num: len + 1); |
| 589 | file.read(s: vocab[i].get(), n: len); |
| 590 | vocab[i][len] = '\0'; |
| 591 | } |
| 592 | file.close(); |
| 593 | } |
| 594 | |
| 595 | void Sampler::build_sampler(int vocab_size, float temperature, float topp, unsigned long long rng_seed) |
| 596 | { |
| 597 | this->vocab_size = vocab_size; |
| 598 | this->temperature = temperature; |
| 599 | this->topp = topp; |
| 600 | rng_state = rng_seed; |
| 601 | // buffer only used with nucleus sampling; may not need but it's ~small |
| 602 | probindex = std::make_unique<ProbIndex[]>(num: vocab_size); |
| 603 | } |
| 604 | |
| 605 | int Sampler::sample_argmax(float *probabilities, int n) |
| 606 | { |
| 607 | // return the index that has the highest probability |
| 608 | int max_i = 0; |
| 609 | float max_p = probabilities[0]; |
| 610 | for (int i = 1; i < n; i++) |
| 611 | { |
| 612 | if (probabilities[i] > max_p) |
| 613 | { |
| 614 | max_i = i; |
| 615 | max_p = probabilities[i]; |
| 616 | } |
| 617 | } |
| 618 | return max_i; |
| 619 | } |
| 620 | |
| 621 | int Sampler::sample_mult(float *probabilities, int n, float coin) |
| 622 | { |
| 623 | // sample index from probabilities (they must sum to 1!) |
| 624 | // coin is a random number in [0, 1), usually from random_f32() |
| 625 | float cdf = 0.0f; |
| 626 | for (int i = 0; i < n; i++) |
| 627 | { |
| 628 | cdf += probabilities[i]; |
| 629 | if (coin < cdf) |
| 630 | { |
| 631 | return i; |
| 632 | } |
| 633 | } |
| 634 | return n - 1; // in case of rounding errors |
| 635 | } |
| 636 | |
| 637 | int Sampler::sample_topp(float *probabilities, int n, float topp, std::unique_ptr<ProbIndex[]> &probindex, float coin) |
| 638 | { |
| 639 | // top-p sampling (or "nucleus sampling") samples from the smallest set of |
| 640 | // tokens that exceed probability topp. This way we never sample tokens that |
| 641 | // have very low probabilities and are less likely to go "off the rails". |
| 642 | // coin is a random number in [0, 1), usually from random_f32() |
| 643 | |
| 644 | int n0 = 0; |
| 645 | // quicksort indices in descending order of probabilities |
| 646 | // values smaller than (1 - topp) / (n - 1) cannot be part of the result |
| 647 | // so for efficiency we crop these out as candidates before sorting |
| 648 | const float cutoff = (1.0f - topp) / (n - 1); |
| 649 | for (int i = 0; i < n; i++) |
| 650 | { |
| 651 | if (probabilities[i] >= cutoff) |
| 652 | { |
| 653 | probindex[n0].index = i; |
| 654 | probindex[n0].prob = probabilities[i]; |
| 655 | n0++; |
| 656 | } |
| 657 | } |
| 658 | std::sort(first: probindex.get(), last: probindex.get() + n0, comp: compare_probindex); |
| 659 | |
| 660 | // truncate the list where cumulative probability exceeds topp |
| 661 | float cumulative_prob = 0.0f; |
| 662 | int last_idx = n0 - 1; // in case of rounding errors consider all elements |
| 663 | for (int i = 0; i < n0; i++) |
| 664 | { |
| 665 | cumulative_prob += probindex[i].prob; |
| 666 | if (cumulative_prob > topp) |
| 667 | { |
| 668 | last_idx = i; |
| 669 | break; // we've exceeded topp by including last_idx |
| 670 | } |
| 671 | } |
| 672 | |
| 673 | // sample from the truncated list |
| 674 | float r = coin * cumulative_prob; |
| 675 | float cdf = 0.0f; |
| 676 | for (int i = 0; i <= last_idx; i++) |
| 677 | { |
| 678 | cdf += probindex[i].prob; |
| 679 | if (r < cdf) |
| 680 | { |
| 681 | return probindex[i].index; |
| 682 | } |
| 683 | } |
| 684 | return probindex[last_idx].index; // in case of rounding errors |
| 685 | } |
| 686 | |
| 687 | int Sampler::sample(float *logits) |
| 688 | { |
| 689 | // sample the token given the logits and some hyperparameters |
| 690 | int next; |
| 691 | if (temperature == 0.0f) |
| 692 | { |
| 693 | // greedy argmax sampling: take the token with the highest probability |
| 694 | next = sample_argmax(probabilities: logits, n: vocab_size); |
| 695 | } |
| 696 | else |
| 697 | { |
| 698 | // apply the temperature to the logits |
| 699 | for (int q = 0; q < vocab_size; q++) |
| 700 | { |
| 701 | logits[q] /= temperature; |
| 702 | } |
| 703 | // apply softmax to the logits to get the probabilities for next token |
| 704 | softmax(x: logits, size: vocab_size); |
| 705 | // flip a (float) coin (this is our source of entropy for sampling) |
| 706 | float coin = random_f32(state: &rng_state); |
| 707 | // we sample from this distribution to get the next token |
| 708 | if (topp <= 0 || topp >= 1) |
| 709 | { |
| 710 | // simply sample from the predicted probability distribution |
| 711 | next = sample_mult(probabilities: logits, n: vocab_size, coin); |
| 712 | } |
| 713 | else |
| 714 | { |
| 715 | // top-p (nucleus) sampling, clamping the least likely tokens to zero |
| 716 | next = sample_topp(probabilities: logits, n: vocab_size, topp, probindex, coin); |
| 717 | } |
| 718 | } |
| 719 | return next; |
| 720 | } |
| 721 | |
| 722 | void Tokenizer::encode(const std::string &text, const int8_t &bos, const int8_t &eos, std::unique_ptr<int[]> &tokens, int &n_tokens) |
| 723 | { |
| 724 | // encode the string text (input) into an upper-bound preallocated tokens[] array |
| 725 | // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2) |
| 726 | if (text.empty()) |
| 727 | { |
| 728 | std::cerr << "cannot encode NULL text" << '\n'; |
| 729 | std::exit(EXIT_FAILURE); |
| 730 | } |
| 731 | |
| 732 | if (!sorted_vocab) |
| 733 | { |
| 734 | // lazily malloc and sort the vocabulary |
| 735 | sorted_vocab = std::make_unique<TokenIndex[]>(num: vocab_size); |
| 736 | for (int i = 0; i < vocab_size; i++) |
| 737 | { |
| 738 | sorted_vocab[i].str = std::string(vocab[i].get()); |
| 739 | sorted_vocab[i].id = i; |
| 740 | } |
| 741 | std::sort(first: sorted_vocab.get(), last: sorted_vocab.get() + vocab_size, comp: compare_tokens); |
| 742 | } |
| 743 | |
| 744 | // create a temporary buffer that will store merge candidates of always two consecutive tokens |
| 745 | // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1) |
| 746 | std::string str_buffer; |
| 747 | str_buffer.resize(n: max_token_length * 2 + 1 + 2); |
| 748 | size_t str_len = 0; |
| 749 | |
| 750 | // start at 0 tokens |
| 751 | n_tokens = 0; |
| 752 | |
| 753 | // add optional BOS (=1) token, if desired |
| 754 | if (bos) |
| 755 | tokens[(n_tokens)++] = 1; |
| 756 | |
| 757 | // add_dummy_prefix is true by default |
| 758 | // so prepend a dummy prefix token to the input string, but only if text != "" |
| 759 | // TODO: pretty sure this isn't correct in the general case but I don't have the |
| 760 | // energy to read more of the sentencepiece code to figure out what it's doing |
| 761 | if (text[0] != '\0') |
| 762 | { |
| 763 | int dummy_prefix = str_lookup(str: " " , sorted_vocab, vocab_size); |
| 764 | tokens[(n_tokens)++] = dummy_prefix; |
| 765 | } |
| 766 | |
| 767 | // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: |
| 768 | // Code point ↔ UTF-8 conversion |
| 769 | // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4 |
| 770 | // U+0000 U+007F 0xxxxxxx |
| 771 | // U+0080 U+07FF 110xxxxx 10xxxxxx |
| 772 | // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx |
| 773 | // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx |
| 774 | |
| 775 | // process the raw (UTF-8) byte sequence of the input string |
| 776 | for (const char *c = text.c_str(); *c != '\0'; c++) |
| 777 | { |
| 778 | // reset buffer if the current byte is ASCII or a leading byte |
| 779 | // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest |
| 780 | // 0x80 is 10000000 |
| 781 | // in UTF-8, all continuation bytes start with "10" in first two bits |
| 782 | // so in English this is: "if this byte is not a continuation byte" |
| 783 | if ((*c & 0xC0) != 0x80) |
| 784 | { |
| 785 | // this byte must be either a leading byte (11...) or an ASCII char (0x...) |
| 786 | // => reset our location, as we're starting a new UTF-8 codepoint |
| 787 | str_len = 0; |
| 788 | } |
| 789 | |
| 790 | // append the current byte to the buffer |
| 791 | str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line |
| 792 | str_buffer[str_len] = '\0'; |
| 793 | |
| 794 | // while the next character is a continuation byte, continue appending |
| 795 | // but if there are too many of them, just stop to avoid overrunning str_buffer size. |
| 796 | if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) |
| 797 | { |
| 798 | continue; |
| 799 | } |
| 800 | |
| 801 | // ok c+1 is not a continuation byte, so we've read in a full codepoint |
| 802 | int id = str_lookup(str: str_buffer, sorted_vocab, vocab_size); |
| 803 | |
| 804 | if (id != -1) |
| 805 | { |
| 806 | // we found this codepoint in vocab, add it as a token |
| 807 | tokens[(n_tokens)++] = id; |
| 808 | } |
| 809 | else |
| 810 | { |
| 811 | // byte_fallback encoding: just encode each byte as a token |
| 812 | // +3 is here because the first 3 vocab elements are <unk>, <s>, </s> |
| 813 | // so the individual bytes only start at index 3 |
| 814 | for (decltype(str_len) i = 0; i < str_len; i++) |
| 815 | { |
| 816 | tokens[(n_tokens)++] = (unsigned char) str_buffer[i] + 3; |
| 817 | } |
| 818 | } |
| 819 | str_len = 0; // protect against a sequence of stray UTF8 continuation bytes |
| 820 | } |
| 821 | |
| 822 | // merge the best consecutive pair each iteration, according the scores in vocab_scores |
| 823 | while (1) |
| 824 | { |
| 825 | float best_score = -1e10; |
| 826 | int best_id = -1; |
| 827 | int best_idx = -1; |
| 828 | |
| 829 | for (int i = 0; i < (n_tokens - 1); i++) |
| 830 | { |
| 831 | // check if we can merge the pair (tokens[i], tokens[i+1]) |
| 832 | // sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); |
| 833 | int id = str_lookup(str: str_buffer, sorted_vocab, vocab_size); |
| 834 | if (id != -1 && vocab_scores[id] > best_score) |
| 835 | { |
| 836 | // this merge pair exists in vocab! record its score and position |
| 837 | best_score = vocab_scores[id]; |
| 838 | best_id = id; |
| 839 | best_idx = i; |
| 840 | } |
| 841 | } |
| 842 | |
| 843 | if (best_idx == -1) |
| 844 | { |
| 845 | break; // we couldn't find any more pairs to merge, so we're done |
| 846 | } |
| 847 | |
| 848 | // merge the consecutive pair (best_idx, best_idx+1) into new token best_id |
| 849 | tokens[best_idx] = best_id; |
| 850 | // delete token at position best_idx+1, shift the entire sequence back 1 |
| 851 | for (int i = best_idx + 1; i < (n_tokens - 1); i++) |
| 852 | { |
| 853 | tokens[i] = tokens[i + 1]; |
| 854 | } |
| 855 | (n_tokens)--; // token length decreased |
| 856 | } |
| 857 | |
| 858 | // add optional EOS (=2) token, if desired |
| 859 | if (eos) |
| 860 | tokens[(n_tokens)++] = 2; |
| 861 | } |
| 862 | |
| 863 | std::string Tokenizer::decode(int prev_token, int token) |
| 864 | { |
| 865 | char *piece = vocab[token].get(); |
| 866 | // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) |
| 867 | if (prev_token == 1 && piece[0] == ' ') |
| 868 | { |
| 869 | piece++; |
| 870 | } |
| 871 | // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' |
| 872 | // parse this and convert and return the actual byte |
| 873 | unsigned char byte_val; |
| 874 | if (sscanf(buffer: piece, format: "<0x%02hhX>" , &byte_val) == 1) |
| 875 | { |
| 876 | piece = (char *) byte_pieces + byte_val * 2; |
| 877 | } |
| 878 | return std::string(piece); |
| 879 | } |
| 880 | |
| 881 | template<typename T> |
| 882 | void generate(Transformer<T> &transformer, Tokenizer &tokenizer, Sampler &sampler, std::string &prompt, int steps) |
| 883 | { |
| 884 | std::string empty_prompt(1, '\0'); |
| 885 | if (prompt.empty()) |
| 886 | { |
| 887 | prompt = empty_prompt; |
| 888 | } |
| 889 | // encode the (string) prompt into tokens sequence |
| 890 | int num_prompt_tokens = 0; |
| 891 | std::unique_ptr<int[]> prompt_tokens = std::make_unique<int[]>(num: strlen(s: prompt.c_str()) + 3); // +3 for '\0', ?BOS, ?EOS |
| 892 | // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); |
| 893 | tokenizer.encode(text: prompt, bos: 1, eos: 0, tokens&: prompt_tokens, n_tokens&: num_prompt_tokens); |
| 894 | if (num_prompt_tokens < 1) |
| 895 | { |
| 896 | std::cerr << "something is wrong, expected at least 1 prompt token" << '\n'; |
| 897 | std::exit(EXIT_FAILURE); |
| 898 | } |
| 899 | // start the main loop |
| 900 | long start = 0; // used to time our code, only initialized after first iteration |
| 901 | int next; // will store the next token in the sequence |
| 902 | int token = prompt_tokens[0]; // kick off with the first token in the prompt |
| 903 | int pos = 0; // position in the sequence |
| 904 | while (pos < steps) |
| 905 | { |
| 906 | // forward the transformer to get logits for the next token |
| 907 | float *logits = transformer.forward(token, pos); |
| 908 | |
| 909 | // advance the state machine |
| 910 | if (pos < num_prompt_tokens - 1) |
| 911 | { |
| 912 | // if we are still processing the input prompt, force the next prompt token |
| 913 | next = prompt_tokens[pos + 1]; |
| 914 | } |
| 915 | else |
| 916 | { |
| 917 | // otherwise sample the next token from the logits |
| 918 | next = sampler.sample(logits); |
| 919 | } |
| 920 | pos++; |
| 921 | |
| 922 | // data-dependent terminating condition: the BOS (=1) token delimits sequences |
| 923 | if (next == 1) |
| 924 | { |
| 925 | break; |
| 926 | } |
| 927 | |
| 928 | // print the token as string, decode it with the Tokenizer object |
| 929 | std::string piece = tokenizer.decode(prev_token: token, token: next); |
| 930 | safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes |
| 931 | fflush(stream: stdout); |
| 932 | token = next; |
| 933 | |
| 934 | // init the timer here because the first iteration can be slower |
| 935 | if (start == 0) |
| 936 | { |
| 937 | start = time_in_ms(); |
| 938 | } |
| 939 | } |
| 940 | std::cout << "\n" ; |
| 941 | |
| 942 | // report achieved tok/s (pos-1 because the timer starts after first iteration) |
| 943 | if (pos > 1) |
| 944 | { |
| 945 | long end = time_in_ms(); |
| 946 | std::cerr << "achieved tok/s: " << (pos - 1) / static_cast<double>(end - start) * 1000 << std::endl; |
| 947 | } |
| 948 | } |
| 949 | |
| 950 | template<typename T> |
| 951 | void chat(Transformer<T> &transformer, Tokenizer &tokenizer, Sampler &sampler, std::string &cli_user_prompt, std::string &cli_system_prompt, int steps) |
| 952 | { |
| 953 | // buffers for reading the system prompt and user prompt from stdin |
| 954 | // you'll notice they are somewhat haphazardly and unsafely set atm |
| 955 | std::string system_prompt; |
| 956 | std::string user_prompt; |
| 957 | std::string rendered_prompt; |
| 958 | int num_prompt_tokens = 0; |
| 959 | std::unique_ptr<int[]> prompt_tokens = std::make_unique<int[]>(num: 1152); |
| 960 | int user_idx; |
| 961 | |
| 962 | // start the main loop |
| 963 | int8_t user_turn = 1; // user starts |
| 964 | int next; // will store the next token in the sequence |
| 965 | int token; // stores the current token to feed into the transformer |
| 966 | int pos = 0; // position in the sequence |
| 967 | while (pos < steps) |
| 968 | { |
| 969 | |
| 970 | // when it is the user's turn to contribute tokens to the dialog... |
| 971 | if (user_turn) |
| 972 | { |
| 973 | // get the (optional) system prompt at position 0 |
| 974 | if (pos == 0) |
| 975 | { |
| 976 | // at position 0, the user can also contribute a system prompt |
| 977 | if (cli_system_prompt.empty()) |
| 978 | { |
| 979 | // system prompt was not passed in, attempt to get it from stdin |
| 980 | read_stdin(guide: "Enter system prompt (optional): " , buffer&: system_prompt, max_len: sizeof(system_prompt)); |
| 981 | } |
| 982 | else |
| 983 | { |
| 984 | // system prompt was passed in, use it |
| 985 | system_prompt = cli_system_prompt; |
| 986 | } |
| 987 | } |
| 988 | // get the user prompt |
| 989 | if (pos == 0 && !cli_user_prompt.empty()) |
| 990 | { |
| 991 | // user prompt for position 0 was passed in, use it |
| 992 | user_prompt = cli_user_prompt; |
| 993 | } |
| 994 | else |
| 995 | { |
| 996 | // otherwise get user prompt from stdin |
| 997 | read_stdin(guide: "User: " , buffer&: user_prompt, max_len: sizeof(user_prompt)); |
| 998 | } |
| 999 | // render user/system prompts into the Llama 2 Chat schema |
| 1000 | if (pos == 0 && !system_prompt.empty()) |
| 1001 | { |
| 1002 | std::string system_template = "[INST] <<SYS>>\n" + system_prompt + "\n<</SYS>>\n\n" + user_prompt + " [/INST]" ; |
| 1003 | rendered_prompt = system_template; |
| 1004 | } |
| 1005 | else |
| 1006 | { |
| 1007 | std::string user_template = "[INST] " + user_prompt + " [/INST]" ; |
| 1008 | rendered_prompt = user_template; |
| 1009 | } |
| 1010 | // encode the rendered prompt into tokens |
| 1011 | tokenizer.encode(text: rendered_prompt, bos: 1, eos: 0, tokens&: prompt_tokens, n_tokens&: num_prompt_tokens); |
| 1012 | user_idx = 0; // reset the user index |
| 1013 | user_turn = 0; |
| 1014 | std::cout << "Assistant: " ; |
| 1015 | } |
| 1016 | |
| 1017 | // determine the token to pass into the transformer next |
| 1018 | if (user_idx < num_prompt_tokens) |
| 1019 | { |
| 1020 | // if we are still processing the input prompt, force the next prompt token |
| 1021 | token = prompt_tokens[user_idx++]; |
| 1022 | } |
| 1023 | else |
| 1024 | { |
| 1025 | // otherwise use the next token sampled from previous turn |
| 1026 | token = next; |
| 1027 | } |
| 1028 | // EOS (=2) token ends the Assistant turn |
| 1029 | if (token == 2) |
| 1030 | { |
| 1031 | user_turn = 1; |
| 1032 | } |
| 1033 | |
| 1034 | // forward the transformer to get logits for the next token |
| 1035 | float *logits = transformer.forward(token, pos); |
| 1036 | next = sampler.sample(logits); |
| 1037 | pos++; |
| 1038 | |
| 1039 | if (user_idx >= num_prompt_tokens && next != 2) |
| 1040 | { |
| 1041 | // the Assistant is responding, so print its output |
| 1042 | std::string piece = tokenizer.decode(prev_token: token, token: next); |
| 1043 | safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes |
| 1044 | fflush(stream: stdout); |
| 1045 | } |
| 1046 | if (next == 2) |
| 1047 | { |
| 1048 | std::cout << "\n" ; |
| 1049 | } |
| 1050 | } |
| 1051 | std::cout << "\n" ; |
| 1052 | } |
| 1053 | |
| 1054 | int main(int argc, char *argv[]) |
| 1055 | { |
| 1056 | std::string checkpoint_path; |
| 1057 | std::string tokenizer_path = "/initrd/assets/tokenizer.bin" ; |
| 1058 | float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher |
| 1059 | float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower |
| 1060 | int steps = 256; // number of steps to run for |
| 1061 | std::string prompt; |
| 1062 | unsigned long long rng_seed = 0; // seed rng with time by default |
| 1063 | std::string mode = "generate" ; |
| 1064 | std::string system_prompt; |
| 1065 | if (argc >= 2) |
| 1066 | { |
| 1067 | checkpoint_path = argv[1]; |
| 1068 | } |
| 1069 | else |
| 1070 | { |
| 1071 | checkpoint_path = "/initrd/assets/stories15M.bin" ; |
| 1072 | } |
| 1073 | |
| 1074 | for (int i = 2; i < argc; i += 2) |
| 1075 | { |
| 1076 | // do some basic validation |
| 1077 | if (i + 1 >= argc) |
| 1078 | { |
| 1079 | error_usage(); |
| 1080 | } // must have arg after flag |
| 1081 | if (argv[i][0] != '-') |
| 1082 | { |
| 1083 | error_usage(); |
| 1084 | } // must start with dash |
| 1085 | if (strlen(s: argv[i]) != 2) |
| 1086 | { |
| 1087 | error_usage(); |
| 1088 | } // must be -x (one dash, one letter) |
| 1089 | // read in the args |
| 1090 | if (argv[i][1] == 't') |
| 1091 | { |
| 1092 | temperature = atof(string: argv[i + 1]); |
| 1093 | } |
| 1094 | else if (argv[i][1] == 'p') |
| 1095 | { |
| 1096 | topp = atof(string: argv[i + 1]); |
| 1097 | } |
| 1098 | else if (argv[i][1] == 's') |
| 1099 | { |
| 1100 | rng_seed = atoi(string: argv[i + 1]); |
| 1101 | } |
| 1102 | else if (argv[i][1] == 'n') |
| 1103 | { |
| 1104 | steps = atoi(string: argv[i + 1]); |
| 1105 | } |
| 1106 | else if (argv[i][1] == 'i') |
| 1107 | { |
| 1108 | prompt = argv[i + 1]; |
| 1109 | } |
| 1110 | else if (argv[i][1] == 'z') |
| 1111 | { |
| 1112 | tokenizer_path = argv[i + 1]; |
| 1113 | } |
| 1114 | else if (argv[i][1] == 'm') |
| 1115 | { |
| 1116 | mode = argv[i + 1]; |
| 1117 | } |
| 1118 | else if (argv[i][1] == 'y') |
| 1119 | { |
| 1120 | system_prompt = argv[i + 1]; |
| 1121 | } |
| 1122 | else |
| 1123 | { |
| 1124 | error_usage(); |
| 1125 | } |
| 1126 | } |
| 1127 | if (rng_seed <= 0) |
| 1128 | rng_seed = (unsigned int) time(NULL); |
| 1129 | if (temperature < 0.0) |
| 1130 | temperature = 0.0; |
| 1131 | if (topp < 0.0 || 1.0 < topp) |
| 1132 | topp = 0.9; |
| 1133 | if (steps < 0) |
| 1134 | steps = 0; |
| 1135 | if (is_quantized_model(checkpoint_path)) |
| 1136 | { |
| 1137 | Transformer<QuantizedTensor> transformer; |
| 1138 | transformer.load_model(checkpoint_path); |
| 1139 | if (steps == 0 || steps > transformer.config.seq_len) |
| 1140 | steps = transformer.config.seq_len; // ovrerride to ~max length |
| 1141 | Tokenizer tokenizer; |
| 1142 | tokenizer.build_tokenizer(tokenizer_path, size_for_vacab: transformer.config.vocab_size); |
| 1143 | Sampler sampler; |
| 1144 | sampler.build_sampler(vocab_size: transformer.config.vocab_size, temperature, topp, rng_seed); |
| 1145 | |
| 1146 | // run! |
| 1147 | if (mode == "generate" ) |
| 1148 | { |
| 1149 | generate(transformer, tokenizer, sampler, prompt, steps); |
| 1150 | } |
| 1151 | else if (mode == "chat" ) |
| 1152 | { |
| 1153 | chat(transformer, tokenizer, sampler, cli_user_prompt&: prompt, cli_system_prompt&: system_prompt, steps); |
| 1154 | } |
| 1155 | else |
| 1156 | { |
| 1157 | std::cerr << "unknown mode: " << mode << "\n" << std::endl; |
| 1158 | error_usage(); |
| 1159 | } |
| 1160 | } |
| 1161 | else |
| 1162 | { |
| 1163 | Transformer<float> transformer; |
| 1164 | transformer.load_model(checkpoint_path); |
| 1165 | if (steps == 0 || steps > transformer.config.seq_len) |
| 1166 | steps = transformer.config.seq_len; // ovrerride to ~max length |
| 1167 | Tokenizer tokenizer; |
| 1168 | tokenizer.build_tokenizer(tokenizer_path, size_for_vacab: transformer.config.vocab_size); |
| 1169 | Sampler sampler; |
| 1170 | sampler.build_sampler(vocab_size: transformer.config.vocab_size, temperature, topp, rng_seed); |
| 1171 | |
| 1172 | // run! |
| 1173 | if (mode == "generate" ) |
| 1174 | { |
| 1175 | generate(transformer, tokenizer, sampler, prompt, steps); |
| 1176 | } |
| 1177 | else if (mode == "chat" ) |
| 1178 | { |
| 1179 | chat(transformer, tokenizer, sampler, cli_user_prompt&: prompt, cli_system_prompt&: system_prompt, steps); |
| 1180 | } |
| 1181 | else |
| 1182 | { |
| 1183 | std::cerr << "unknown mode: " << mode << "\n" << std::endl; |
| 1184 | error_usage(); |
| 1185 | } |
| 1186 | } |
| 1187 | |
| 1188 | return 0; |
| 1189 | } |
| 1190 | |