// NOTE(review): this chunk is an elided/garbled extract. The leading numbers
// ("181", "187", ...) are original-file line numbers embedded in the text, and
// many original lines are missing (e.g. 183-186, 198-200, 203-212, 286-299),
// including this function's return type and name — the error strings below
// identify it as sampler::sample_with_params. Comments describe only what the
// visible lines show.
//
// Visible tail of the parameter list: vocab pointer, params struct, and an
// optional grammar sampler defaulting to nullptr. "const P ¶ms," looks
// mis-encoded (likely "&params" collapsed via an HTML entity) — TODO confirm
// against the repository.
181 const llama_vocab *vocab,
const P ¶ms,
182 llama_sampler *grammarSampler =
nullptr) {
// Resolve effective sampling parameters via as_value(), falling back to
// llama.cpp-style defaults when unset: seed <- time(nullptr) (the "uint32_t
// seed =" left-hand side sits on an elided line), top_k=40, top_p=0.95,
// min_p=0.05, typical_p=1.0 (i.e. disabled, see the "< 1.0f" guards below),
// temperature=0.8, penalty window 64, repeat=1.0 / freq=0.0 / present=0.0
// (all three "no penalty" values, see the guard before init_penalties).
187 as_value(params.seed,
static_cast<uint32_t
>(std::time(
nullptr)));
188 int32_t top_k = as_value(params.top_k,
static_cast<int32_t
>(40));
189 float top_p = as_value(params.top_p, 0.95f);
190 float min_p = as_value(params.min_p, 0.05f);
191 float typical_p = as_value(params.typical_p, 1.0f);
192 float temperature = as_value(params.temperature, 0.8f);
193 int32_t penalty_last_n =
194 as_value(params.penalty_last_n,
static_cast<int32_t
>(64));
195 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
196 float penalty_freq = as_value(params.penalty_freq, 0.0f);
197 float penalty_present = as_value(params.penalty_present, 0.0f);
// Tail of a debug log of the resolved parameters (the call site itself is on
// elided lines 198-200).
201 "top_k=%d, top_p=%.2f, min_p=%.2f",
202 temperature,
static_cast<int>(top_k), top_p, min_p);
// Precondition checks: the guarding if-statements (presumably on ctx and
// vocab — elided lines 203-207) throw on NULL inputs.
205 throw std::runtime_error(
"sampler::sample_with_params - NULL context");
208 throw std::runtime_error(
"sampler::sample_with_params - NULL vocabulary");
// ---- Path 1: grammar-constrained sampling ------------------------------
213 if (grammarSampler) {
215 "using grammar-constrained sampling");
221 const int n_vocab = llama_vocab_n_tokens(vocab);
// Throws when the vocabulary size fails a validity check (the condition is
// on elided line 222, presumably n_vocab <= 0).
223 throw std::runtime_error(
224 "sampler::sample_with_params - Invalid vocabulary size");
// Build an explicit candidates array over the full vocabulary from the raw
// logits ("logits" is declared on an elided line — presumably fetched from
// ctx via llama_get_logits; TODO confirm).
228 std::vector<llama_token_data> candidates(n_vocab);
229 for (
int i = 0; i < n_vocab; i++) {
231 llama_token_data{
static_cast<llama_token
>(i), logits[i], 0.0f};
234 llama_token_data_array cur_p = {
235 candidates.data(),
static_cast<size_t>(n_vocab),
// Build a sampler chain with perf counters disabled. Stage order:
// penalties (only if any penalty is active) -> top_k -> typical (only if
// enabled) -> top_p -> min_p -> temp -> dist(seed).
241 llama_sampler_chain_params chain_params =
242 llama_sampler_chain_default_params();
243 chain_params.no_perf =
true;
244 auto *chain = llama_sampler_chain_init(chain_params);
248 if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
249 penalty_present != 0.0f) {
250 llama_sampler_chain_add(
251 chain, llama_sampler_init_penalties(penalty_last_n, penalty_repeat,
252 penalty_freq, penalty_present));
255 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
257 if (typical_p < 1.0f) {
258 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
261 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
264 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
// NOTE(review): unlike create_chain later in this file, no visible
// temperature <= 0 greedy fallback here — may be on an elided line, or a
// genuine inconsistency; verify against the full source.
266 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
267 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
// Apply the grammar sampler FIRST (masking invalid tokens), then the
// sampling chain, both against the explicit candidate array.
271 llama_sampler_apply(grammarSampler, &cur_p);
274 llama_sampler_apply(chain, &cur_p);
// Guard against no token being selected; the chain is freed before the
// throw, so this path does not leak it.
277 if (cur_p.selected == -1) {
278 llama_sampler_free(chain);
279 throw std::runtime_error(
280 "No selected token during sampling - check sampling configuration");
282 llama_token result = cur_p.data[cur_p.selected].id;
285 llama_sampler_free(chain);
// ---- Path 2: no grammar — lightweight chain over ctx's last logits ----
// (The "else"/logging lines 286-299 are elided.)
294 "lightweight chain approach");
300 const int n_vocab = llama_vocab_n_tokens(vocab);
302 throw std::runtime_error(
303 "sampler::sample_with_params - Invalid vocabulary size");
307 llama_sampler_chain_params chain_params =
308 llama_sampler_chain_default_params();
309 chain_params.no_perf =
true;
311 auto *sampler_chain = llama_sampler_chain_init(chain_params);
312 if (!sampler_chain) {
313 throw std::runtime_error(
314 "sampler::sample_with_params - Failed to create sampler chain");
318 "adding samplers...");
// Same stage order as the grammar path: conditional penalties, top_k,
// conditional typical, top_p, min_p, temp, dist(seed).
321 if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
322 penalty_present != 0.0f) {
324 "(repeat=%.2f, freq=%.2f, present=%.2f, last_n=%d)",
325 penalty_repeat, penalty_freq, penalty_present,
327 llama_sampler_chain_add(sampler_chain, llama_sampler_init_penalties(
328 penalty_last_n, penalty_repeat,
329 penalty_freq, penalty_present));
335 static_cast<int>(top_k));
336 llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_k(top_k));
340 if (typical_p < 1.0f) {
343 llama_sampler_chain_add(sampler_chain,
344 llama_sampler_init_typical(typical_p, 1));
350 llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_p(top_p, 1));
356 llama_sampler_chain_add(sampler_chain, llama_sampler_init_min_p(min_p, 1));
362 llama_sampler_chain_add(sampler_chain, llama_sampler_init_temp(temperature));
366 llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(seed));
// Sample directly from the context's logits (index -1 = last token's
// logits), then free the chain before logging the chosen token.
369 llama_token result = llama_sampler_sample(sampler_chain, ctx, -1);
372 llama_sampler_free(sampler_chain);
375 "(temp=%.2f, top_k=%d, top_p=%.2f, min_p=%.2f)",
376 result, temperature,
static_cast<int>(top_k), top_p, min_p);
// NOTE(review): elided extract of what the error string below identifies as
// sampler::create_chain (original lines ~470-512). The function signature
// (before line 470) and the trailing code (presumably "return chain;") are
// outside this view. The embedded numbers are original-file line numbers.
//
// Resolve the same parameter defaults used by sample_with_params:
// penalty window 64, repeat=1.0 / freq=0.0 / present=0.0 (all "off"),
// top_k=40, top_p=0.95, min_p=0.05, typical_p=1.0 (disabled),
// temperature=0.8, seed <- time(nullptr).
470 int32_t penalty_last_n = as_value(params.penalty_last_n,
static_cast<int32_t
>(64));
471 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
472 float penalty_freq = as_value(params.penalty_freq, 0.0f);
473 float penalty_present = as_value(params.penalty_present, 0.0f);
474 int32_t top_k = as_value(params.top_k,
static_cast<int32_t
>(40));
475 float top_p = as_value(params.top_p, 0.95f);
476 float min_p = as_value(params.min_p, 0.05f);
477 float typical_p = as_value(params.typical_p, 1.0f);
478 float temperature = as_value(params.temperature, 0.8f);
479 uint32_t seed = as_value(params.seed,
static_cast<uint32_t
>(std::time(
nullptr)));
// Create a sampler chain with perf counters disabled; throw if allocation
// fails (the "if (!chain)" condition is on elided lines 484-485).
481 llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
482 chain_params.no_perf =
true;
483 llama_sampler* chain = llama_sampler_chain_init(chain_params);
486 throw std::runtime_error(
"sampler::create_chain - Failed to create sampler chain");
// Stage order matches sample_with_params: conditional penalties, top_k,
// conditional typical, top_p, min_p.
490 if (penalty_repeat != 1.0f || penalty_freq != 0.0f || penalty_present != 0.0f) {
491 llama_sampler_chain_add(chain,
492 llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present));
495 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
497 if (typical_p < 1.0f) {
498 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
501 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
504 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
// Final stage: greedy argmax when temperature <= 0, otherwise temperature
// scaling + seeded probabilistic pick. Elided line 510 presumably holds the
// "} else {" separating the two branches — TODO confirm.
508 if (temperature <= 0.0f) {
509 llama_sampler_chain_add(chain, llama_sampler_init_greedy());
511 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
512 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));