// ---------------------------------------------------------------------------
// NOTE(review): this region is a corrupted listing of the original source:
// the file's own line numbers (180, 181, ...) are baked into the text,
// statements are hard-wrapped, and gaps in the embedded numbering show that
// lines were stripped (log-call openings, `if (...) {` guard headers,
// closing braces). Annotations below describe only what is visible; the
// pristine source must be restored before this can compile.
// ---------------------------------------------------------------------------
// sample_with_params — samples a single token from the current logits using a
// throwaway llama.cpp sampler chain built from `params`. Takes the vocabulary,
// a parameter pack, and an optional grammar sampler; the leading parameter(s)
// (presumably the llama_context* referenced as `ctx` below — confirm against
// the full signature above this excerpt) are not visible here.
180 const llama_vocab *vocab,
// NOTE(review): "¶ms" is a mis-encoded "&params" — the HTML entity
// "&para;" was substituted into the text. Restore to `const P &params,`.
const P ¶ms,
// Optional grammar-constrained sampler; nullptr selects the lightweight path.
181 llama_sampler *grammarSampler =
nullptr) {
// Resolve sampling parameters with llama.cpp-style defaults.
// NOTE(review): the `uint32_t seed =` left-hand side of this first statement
// was on a stripped line (182-185) — compare the intact twin at line 478.
// Default seed: current wall-clock time when params.seed is unset.
186 as_value(params.seed,
static_cast<uint32_t
>(std::time(
nullptr)));
187 int32_t top_k = as_value(params.top_k,
static_cast<int32_t
>(40));
188 float top_p = as_value(params.top_p, 0.95f);
189 float min_p = as_value(params.min_p, 0.05f);
// typical_p == 1.0 disables typical sampling (see the `< 1.0f` guard below).
190 float typical_p = as_value(params.typical_p, 1.0f);
191 float temperature = as_value(params.temperature, 0.8f);
192 int32_t penalty_last_n =
193 as_value(params.penalty_last_n,
static_cast<int32_t
>(64));
// penalty_repeat == 1.0 / freq == 0.0 / present == 0.0 are the "no penalty"
// neutral values; the penalties sampler is only added when any differs.
194 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
195 float penalty_freq = as_value(params.penalty_freq, 0.0f);
196 float penalty_present = as_value(params.penalty_present, 0.0f);
// Orphaned tail of a debug-log call whose opening (lines 197-199) was
// stripped — presumably logs the resolved parameters; confirm the macro
// name against the pristine source.
200 "top_k=%d, top_p=%.2f, min_p=%.2f",
201 temperature,
static_cast<int>(top_k), top_p, min_p);
// Guarded throw; the `if (!ctx) {` header (lines 202-203) was stripped.
204 throw std::runtime_error(
"sampler::sample_with_params - NULL context");
// Guarded throw; the `if (!vocab) {` header was stripped.
207 throw std::runtime_error(
"sampler::sample_with_params - NULL vocabulary");
// --- Path 1: grammar-constrained sampling -------------------------------
// Builds an explicit candidate array so the grammar sampler can be applied
// to the token distribution before the regular chain runs.
212 if (grammarSampler) {
214 "using grammar-constrained sampling");
220 const int n_vocab = llama_vocab_n_tokens(vocab);
// Guarded throw; the `if (n_vocab <= 0) {` header (line 221) was stripped.
222 throw std::runtime_error(
223 "sampler::sample_with_params - Invalid vocabulary size");
// One llama_token_data per vocab entry: {id, logit, p}. The `logits`
// pointer is fetched on a stripped line — presumably
// llama_get_logits_ith(ctx, -1); confirm against the pristine source.
227 std::vector<llama_token_data> candidates(n_vocab);
228 for (
int i = 0; i < n_vocab; i++) {
230 llama_token_data{
static_cast<llama_token
>(i), logits[i], 0.0f};
// Wrap the candidates in the view struct the sampler API mutates in place;
// the trailing initializers (selected / sorted fields) were stripped.
233 llama_token_data_array cur_p = {
234 candidates.data(),
static_cast<size_t>(n_vocab),
// Build the sampler chain (perf counters off).
240 llama_sampler_chain_params chain_params =
241 llama_sampler_chain_default_params();
242 chain_params.no_perf =
true;
243 auto *chain = llama_sampler_chain_init(chain_params);
// Penalties first, only when any penalty parameter is non-neutral.
247 if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
248 penalty_present != 0.0f) {
249 llama_sampler_chain_add(
250 chain, llama_sampler_init_penalties(penalty_last_n, penalty_repeat,
251 penalty_freq, penalty_present));
// Truncation samplers in the conventional llama.cpp order:
// top-k -> typical (optional) -> top-p -> min-p -> temperature -> dist.
254 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
256 if (typical_p < 1.0f) {
257 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
260 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
263 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
// NOTE(review): unlike create_chain (line 507), this path adds the
// temperature sampler unconditionally — temperature <= 0 is not mapped to
// greedy here. Confirm whether that inconsistency is intentional.
265 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
266 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
// Grammar constraints are applied to the candidate array first, then the
// regular chain selects from what the grammar left viable.
270 llama_sampler_apply(grammarSampler, &cur_p);
273 llama_sampler_apply(chain, &cur_p);
// selected == -1 means the dist sampler never picked a token.
276 if (cur_p.selected == -1) {
277 llama_sampler_free(chain);
278 throw std::runtime_error(
279 "No selected token during sampling - check sampling configuration");
281 llama_token result = cur_p.data[cur_p.selected].id;
// Chain is freed on both the throw path (line 277) and here; the grammar
// sampler is deliberately NOT freed — presumably caller-owned; confirm.
284 llama_sampler_free(chain);
// --- Path 2: no grammar — lightweight chain -----------------------------
// Same chain construction, but sampling is delegated to
// llama_sampler_sample so no explicit candidate array is needed.
293 "lightweight chain approach");
299 const int n_vocab = llama_vocab_n_tokens(vocab);
// Guarded throw; the `if (n_vocab <= 0) {` header (line 300) was stripped.
301 throw std::runtime_error(
302 "sampler::sample_with_params - Invalid vocabulary size");
306 llama_sampler_chain_params chain_params =
307 llama_sampler_chain_default_params();
308 chain_params.no_perf =
true;
310 auto *sampler_chain = llama_sampler_chain_init(chain_params);
311 if (!sampler_chain) {
312 throw std::runtime_error(
313 "sampler::sample_with_params - Failed to create sampler chain");
317 "adding samplers...");
// Same sampler order as the grammar path (see note at line 265 about the
// unconditional temperature sampler).
320 if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
321 penalty_present != 0.0f) {
323 "(repeat=%.2f, freq=%.2f, present=%.2f, last_n=%d)",
324 penalty_repeat, penalty_freq, penalty_present,
326 llama_sampler_chain_add(sampler_chain, llama_sampler_init_penalties(
327 penalty_last_n, penalty_repeat,
328 penalty_freq, penalty_present));
334 static_cast<int>(top_k));
335 llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_k(top_k));
339 if (typical_p < 1.0f) {
342 llama_sampler_chain_add(sampler_chain,
343 llama_sampler_init_typical(typical_p, 1));
349 llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_p(top_p, 1));
355 llama_sampler_chain_add(sampler_chain, llama_sampler_init_min_p(min_p, 1));
361 llama_sampler_chain_add(sampler_chain, llama_sampler_init_temp(temperature));
365 llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(seed));
// -1 samples from the logits of the last token in the batch.
368 llama_token result = llama_sampler_sample(sampler_chain, ctx, -1);
371 llama_sampler_free(sampler_chain);
// Orphaned tail of a result-logging call; its opening line was stripped.
374 "(temp=%.2f, top_k=%d, top_p=%.2f, min_p=%.2f)",
375 result, temperature,
static_cast<int>(top_k), top_p, min_p);
// ---------------------------------------------------------------------------
// create_chain (visible interior only) — builds and returns a reusable
// llama.cpp sampler chain from `params`, mirroring the parameter defaults
// used by sample_with_params. The signature (before original line 469) and
// the function tail (after line 511, presumably `return chain;`) lie outside
// this excerpt; like the block above, this is a corrupted listing with the
// original line numbers baked in and some lines (braces, `else`) stripped.
// ---------------------------------------------------------------------------
// Resolve parameters with the same neutral defaults as sample_with_params.
469 int32_t penalty_last_n = as_value(params.penalty_last_n,
static_cast<int32_t
>(64));
470 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
471 float penalty_freq = as_value(params.penalty_freq, 0.0f);
472 float penalty_present = as_value(params.penalty_present, 0.0f);
473 int32_t top_k = as_value(params.top_k,
static_cast<int32_t
>(40));
474 float top_p = as_value(params.top_p, 0.95f);
475 float min_p = as_value(params.min_p, 0.05f);
// typical_p == 1.0 disables typical sampling (see the `< 1.0f` guard below).
476 float typical_p = as_value(params.typical_p, 1.0f);
477 float temperature = as_value(params.temperature, 0.8f);
// Default seed: current wall-clock time when params.seed is unset.
478 uint32_t seed = as_value(params.seed,
static_cast<uint32_t
>(std::time(
nullptr)));
// Chain with perf counters disabled.
480 llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
481 chain_params.no_perf =
true;
482 llama_sampler* chain = llama_sampler_chain_init(chain_params);
// Guarded throw; the `if (!chain) {` header (lines 483-484) was stripped.
485 throw std::runtime_error(
"sampler::create_chain - Failed to create sampler chain");
// Penalties only when any penalty parameter is non-neutral.
489 if (penalty_repeat != 1.0f || penalty_freq != 0.0f || penalty_present != 0.0f) {
490 llama_sampler_chain_add(chain,
491 llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present));
// Truncation order: top-k -> typical (optional) -> top-p -> min-p.
494 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
496 if (typical_p < 1.0f) {
497 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
500 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
503 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
// temperature <= 0 selects greedy decoding; otherwise temperature scaling
// plus seeded distribution sampling. NOTE(review): sample_with_params adds
// the temperature sampler unconditionally (line 265/361) — the two code
// paths disagree on how temperature <= 0 is handled; confirm which is
// intended. The `else {` between lines 508 and 510 was stripped.
507 if (temperature <= 0.0f) {
508 llama_sampler_chain_add(chain, llama_sampler_init_greedy());
510 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
511 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));