liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
sampler.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6
#include "common.hpp"
#include "logits.hpp"
#include "tokenizer.hpp"

#include <llama/llama.h>

#include <cstdint>
#include <cstring>
#include <ctime>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <vector>
17
namespace lloyal::detail {

// ===== OPTIONAL VALUE EXTRACTION =====

/// Type trait: true when X is a std::optional specialization.
template <class> struct is_optional : std::false_type {};

template <class U> struct is_optional<std::optional<U>> : std::true_type {};

/// Convenience variable template for is_optional.
template <class X>
inline constexpr bool is_optional_v = is_optional<X>::value;

/// Extract a value from either a plain field or a std::optional field.
///
/// @param x   Source value; either convertible to T or std::optional<T>
/// @param def Fallback returned when x is an empty optional
/// @return The contained/converted value, or def for an empty optional
template <class X, class T> constexpr T as_value(const X &x, T def) {
  if constexpr (is_optional_v<X>) {
    return x.value_or(def);
  } else {
    return static_cast<T>(x);
  }
}

} // namespace lloyal::detail
60
61namespace lloyal {
62
63// ===== SAMPLING PARAMS CONCEPT =====
64
/// C++20 concept: any struct exposing the core sampling parameter fields.
///
/// Each field may be a plain value (e.g. float) or std::optional<T>;
/// consumers normalize both forms through detail::as_value() with
/// per-field defaults. Only field PRESENCE is checked here, not types.
template <class P>
concept SamplingParamsLike = requires(const P &p) {
  p.temperature;
  p.top_k;
  p.top_p;
  p.typical_p;
  p.min_p;
  p.penalty_repeat;
  p.penalty_freq;
  p.penalty_present;
  p.penalty_last_n;
  p.seed;
  // Additional fields for future extensions (not yet required):
  // p.mirostat, p.mirostat_tau, p.mirostat_eta
  // p.dry_multiplier, p.dry_base, p.dry_allowed_length, p.dry_penalty_last_n
  // p.xtc_probability, p.xtc_threshold
  // p.top_n_sigma
};
91
92namespace sampler {
93
94// ===== GREEDY SAMPLING =====
95
111inline llama_token greedy(llama_context *ctx, const llama_vocab *vocab) {
112 LLOYAL_LOG_DEBUG("[sampler::greedy] Sampling next token");
113
114 if (!ctx) {
115 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: NULL context");
116 throw std::runtime_error("sampler::greedy - NULL context");
117 }
118
119 if (!vocab) {
120 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: NULL vocabulary");
121 throw std::runtime_error("sampler::greedy - NULL vocabulary");
122 }
123
124 // Get last-step logits (index -1)
125 // Per llama.cpp maintainers: only works if logits=true was set for that step
126 // lloyal::logits::get() handles null checking and throws descriptive errors
127 const float *logits = lloyal::logits::get(ctx, -1);
128
129 // Get vocabulary size
130 const int n_vocab = llama_vocab_n_tokens(vocab);
131 if (n_vocab <= 0) {
132 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: Invalid vocabulary size: %d",
133 n_vocab);
134 throw std::runtime_error("sampler::greedy - Invalid vocabulary size");
135 }
136
137 // Argmax: Find token with highest probability
138 int best_id = 0;
139 float best_score = logits[0];
140 for (int i = 1; i < n_vocab; ++i) {
141 if (logits[i] > best_score) {
142 best_score = logits[i];
143 best_id = i;
144 }
145 }
146
147 llama_token result = static_cast<llama_token>(best_id);
148 LLOYAL_LOG_DEBUG("[sampler::greedy] Sampled token: %d (score: %.4f)", result,
149 best_score);
150
151 return result;
152}
153
154// ===== PARAMETERIZED SAMPLING =====
155
/// Sample the next token with configurable parameters.
///
/// Two paths:
///  - grammarSampler provided: builds a candidate array from the last-step
///    logits, applies the grammar constraint FIRST, then a per-call sampler
///    chain, and reads the selected candidate.
///  - no grammar: builds a per-call chain and lets llama_sampler_sample()
///    draw directly from the context logits.
///
/// @tparam P       SamplingParamsLike; fields may be plain or std::optional
///                 (empty optionals fall back to the defaults below)
/// @param ctx      Active context (logits must be available at index -1)
/// @param vocab    Model vocabulary (sizes the candidate array)
/// @param params   Sampling parameters
/// @param grammarSampler Optional grammar sampler; applied but NOT freed
///                 here — its lifetime is managed by the caller/context
/// @return Selected token id
/// @throws std::runtime_error on null ctx/vocab, invalid vocab size, chain
///         creation failure, or when no token is selected
template <SamplingParamsLike P>
inline llama_token sample_with_params(llama_context *ctx,
                                      const llama_vocab *vocab, const P &params,
                                      llama_sampler *grammarSampler = nullptr) {
  using detail::as_value;

  // Extract parameters with defaults (handles both T and std::optional<T>).
  // Missing seed defaults to wall-clock time, i.e. non-reproducible runs.
  uint32_t seed =
      as_value(params.seed, static_cast<uint32_t>(std::time(nullptr)));
  int32_t top_k = as_value(params.top_k, static_cast<int32_t>(40));
  float top_p = as_value(params.top_p, 0.95f);
  float min_p = as_value(params.min_p, 0.05f);
  float typical_p = as_value(params.typical_p, 1.0f);
  float temperature = as_value(params.temperature, 0.8f);
  int32_t penalty_last_n =
      as_value(params.penalty_last_n, static_cast<int32_t>(64));
  float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
  float penalty_freq = as_value(params.penalty_freq, 0.0f);
  float penalty_present = as_value(params.penalty_present, 0.0f);

  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Building sampler");
  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] temperature=%.2f, "
                   "top_k=%d, top_p=%.2f, min_p=%.2f",
                   temperature, static_cast<int>(top_k), top_p, min_p);

  if (!ctx) {
    throw std::runtime_error("sampler::sample_with_params - NULL context");
  }
  if (!vocab) {
    throw std::runtime_error("sampler::sample_with_params - NULL vocabulary");
  }

  // ROUTING DECISION: Grammar present → use grammar-aware sampling
  // No grammar → use lightweight chain approach
  if (grammarSampler) {
    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Grammar sampler provided, "
                     "using grammar-constrained sampling");

    // Get logits and build token data array
    // lloyal::logits::get() handles null checking and throws descriptive errors
    const float *logits = lloyal::logits::get(ctx, -1);

    const int n_vocab = llama_vocab_n_tokens(vocab);
    if (n_vocab <= 0) {
      throw std::runtime_error(
          "sampler::sample_with_params - Invalid vocabulary size");
    }

    // Build candidate array from logits (probability field starts at 0;
    // samplers fill it in as needed)
    std::vector<llama_token_data> candidates(n_vocab);
    for (int i = 0; i < n_vocab; i++) {
      candidates[i] =
          llama_token_data{static_cast<llama_token>(i), logits[i], 0.0f};
    }

    llama_token_data_array cur_p = {
        candidates.data(), static_cast<size_t>(n_vocab),
        -1,   // selected (will be set by samplers)
        false // sorted
    };

    // Build sampler chain (WITHOUT grammar - grammar applied separately)
    llama_sampler_chain_params chain_params =
        llama_sampler_chain_default_params();
    chain_params.no_perf = true;
    auto *chain = llama_sampler_chain_init(chain_params);

    // Add samplers in order (penalties → top-k → typical-p → top-p → min-p →
    // temperature → dist)
    if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
        penalty_present != 0.0f) {
      llama_sampler_chain_add(
          chain, llama_sampler_init_penalties(penalty_last_n, penalty_repeat,
                                              penalty_freq, penalty_present));
    }
    if (top_k > 0) {
      llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
    }
    if (typical_p < 1.0f) {
      llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
    }
    if (top_p < 1.0f) {
      llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
    }
    if (min_p > 0.0f) {
      llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
    }
    llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));

    // Apply grammar constraint FIRST (uses persistent parser state from
    // context) so disallowed tokens are masked before filtering/sampling
    llama_sampler_apply(grammarSampler, &cur_p);

    // Then apply chain
    llama_sampler_apply(chain, &cur_p);

    // Get selected token. selected == -1 means the dist sampler never ran
    // or every candidate was masked out.
    if (cur_p.selected == -1) {
      llama_sampler_free(chain);
      throw std::runtime_error(
          "No selected token during sampling - check sampling configuration");
    }
    llama_token result = cur_p.data[cur_p.selected].id;

    // Clean up chain (grammar sampler is managed by context, not freed here)
    llama_sampler_free(chain);

    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Grammar-sampled token: %d",
                     result);
    return result;
  }

  // NO GRAMMAR: Use lightweight chain approach
  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] No grammar, using "
                   "lightweight chain approach");

  // Get logits
  // lloyal::logits::get() handles null checking and throws descriptive errors
  // NOTE(review): result unused here — this call validates logits exist
  // before building the chain; llama_sampler_sample re-reads them below.
  const float *logits = lloyal::logits::get(ctx, -1);

  const int n_vocab = llama_vocab_n_tokens(vocab);
  if (n_vocab <= 0) {
    throw std::runtime_error(
        "sampler::sample_with_params - Invalid vocabulary size");
  }

  // Create llama.cpp sampler chain
  llama_sampler_chain_params chain_params =
      llama_sampler_chain_default_params();
  chain_params.no_perf = true;

  auto *sampler_chain = llama_sampler_chain_init(chain_params);
  if (!sampler_chain) {
    throw std::runtime_error(
        "sampler::sample_with_params - Failed to create sampler chain");
  }

  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Sampler chain created, "
                   "adding samplers...");

  // 1. Repetition penalties (if enabled)
  if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
      penalty_present != 0.0f) {
    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + penalties "
                     "(repeat=%.2f, freq=%.2f, present=%.2f, last_n=%d)",
                     penalty_repeat, penalty_freq, penalty_present,
                     penalty_last_n);
    llama_sampler_chain_add(sampler_chain, llama_sampler_init_penalties(
                                               penalty_last_n, penalty_repeat,
                                               penalty_freq, penalty_present));
  }

  // 2. Top-K sampling (if enabled)
  if (top_k > 0) {
    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + top_k (%d)",
                     static_cast<int>(top_k));
    llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_k(top_k));
  }

  // 3. Typical-P sampling (if enabled)
  if (typical_p < 1.0f) {
    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + typical_p (%.2f)",
                     typical_p);
    llama_sampler_chain_add(sampler_chain,
                            llama_sampler_init_typical(typical_p, 1));
  }

  // 4. Top-P sampling (if enabled)
  if (top_p < 1.0f) {
    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + top_p (%.2f)", top_p);
    llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_p(top_p, 1));
  }

  // 5. Min-P sampling (if enabled)
  if (min_p > 0.0f) {
    LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + min_p (%.2f)", min_p);
    llama_sampler_chain_add(sampler_chain, llama_sampler_init_min_p(min_p, 1));
  }

  // 6. Temperature scaling
  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + temperature (%.2f)",
                   temperature);
  llama_sampler_chain_add(sampler_chain, llama_sampler_init_temp(temperature));

  // 7. Final distribution sampler
  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + dist (seed=%u)", seed);
  llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(seed));

  // Sample from the chain
  llama_token result = llama_sampler_sample(sampler_chain, ctx, -1);

  // Free the sampler chain
  llama_sampler_free(sampler_chain);

  LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Sampled token: %d "
                   "(temp=%.2f, top_k=%d, top_p=%.2f, min_p=%.2f)",
                   result, temperature, static_cast<int>(top_k), top_p, min_p);

  return result;
}
380
381// ===== MODEL-ACCEPTING CONVENIENCE OVERLOADS =====
382//
383// These overloads accept llama_model* and handle vocab extraction internally.
384// They delegate to the vocab-accepting primitives above.
385//
386// Benefits:
387// - Eliminate boilerplate (vocab extraction) in calling code
388// - Reduce code duplication across projects
389// - Backwards compatible - existing code unchanged
390
401inline llama_token greedy(llama_context *ctx, const llama_model *model) {
402 if (!model) {
403 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: model is null");
404 throw std::runtime_error("sampler::greedy - NULL model");
405 }
406
407 const llama_vocab *vocab = lloyal::tokenizer::get_vocab(model);
408 if (!vocab) {
409 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: get_vocab returned null");
410 throw std::runtime_error(
411 "sampler::greedy - Failed to get vocab from model");
412 }
413
414 return greedy(ctx, vocab);
415}
416
429template <SamplingParamsLike P>
430inline llama_token sample_with_params(llama_context *ctx,
431 const llama_model *model, const P &params,
432 llama_sampler *grammarSampler = nullptr) {
433 if (!model) {
434 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] ERROR: model is null");
435 throw std::runtime_error("sampler::sample_with_params - NULL model");
436 }
437
438 const llama_vocab *vocab = lloyal::tokenizer::get_vocab(model);
439 if (!vocab) {
440 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] ERROR: get_vocab "
441 "returned null");
442 throw std::runtime_error(
443 "sampler::sample_with_params - Failed to get vocab from model");
444 }
445
446 return sample_with_params(ctx, vocab, params, grammarSampler);
447}
448
449// ===== PERSISTENT SAMPLER CHAIN PRIMITIVES =====
450//
451// These functions manage persistent sampler chains for use cases like
452// branch-based tree search where the same chain is reused across multiple samples.
453// Unlike sample_with_params() which creates/destroys a chain per call,
454// these allow chain reuse for better performance.
455
465template <SamplingParamsLike P>
466inline llama_sampler* create_chain(const P& params) {
467 using detail::as_value;
468
469 // Extract parameters with defaults
470 int32_t penalty_last_n = as_value(params.penalty_last_n, static_cast<int32_t>(64));
471 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
472 float penalty_freq = as_value(params.penalty_freq, 0.0f);
473 float penalty_present = as_value(params.penalty_present, 0.0f);
474 int32_t top_k = as_value(params.top_k, static_cast<int32_t>(40));
475 float top_p = as_value(params.top_p, 0.95f);
476 float min_p = as_value(params.min_p, 0.05f);
477 float typical_p = as_value(params.typical_p, 1.0f);
478 float temperature = as_value(params.temperature, 0.8f);
479 uint32_t seed = as_value(params.seed, static_cast<uint32_t>(std::time(nullptr)));
480
481 llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
482 chain_params.no_perf = true;
483 llama_sampler* chain = llama_sampler_chain_init(chain_params);
484
485 if (!chain) {
486 throw std::runtime_error("sampler::create_chain - Failed to create sampler chain");
487 }
488
489 // Add samplers in standard order: penalties → filters → temperature/greedy → dist
490 if (penalty_repeat != 1.0f || penalty_freq != 0.0f || penalty_present != 0.0f) {
491 llama_sampler_chain_add(chain,
492 llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present));
493 }
494 if (top_k > 0) {
495 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
496 }
497 if (typical_p < 1.0f) {
498 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
499 }
500 if (top_p < 1.0f) {
501 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
502 }
503 if (min_p > 0.0f) {
504 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
505 }
506
507 // Temperature handling: temp <= 0 means greedy (deterministic argmax)
508 if (temperature <= 0.0f) {
509 llama_sampler_chain_add(chain, llama_sampler_init_greedy());
510 } else {
511 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
512 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
513 }
514
515 return chain;
516}
517
527inline llama_sampler* clone_chain(llama_sampler* chain) {
528 if (!chain) {
529 return nullptr;
530 }
531 return llama_sampler_clone(chain);
532}
533
543inline void reseed_chain(llama_sampler* chain, uint32_t new_seed) {
544 if (!chain) {
545 return;
546 }
547
548 int n = llama_sampler_chain_n(chain);
549 if (n == 0) {
550 return;
551 }
552
553 // Remove last sampler (dist by convention)
554 llama_sampler* old_dist = llama_sampler_chain_remove(chain, n - 1);
555 if (old_dist) {
556 llama_sampler_free(old_dist);
557 }
558
559 // Add new dist with new seed
560 llama_sampler_chain_add(chain, llama_sampler_init_dist(new_seed));
561}
562
568inline void free_chain(llama_sampler* chain) {
569 if (chain) {
570 llama_sampler_free(chain);
571 }
572}
573
582inline void apply(llama_sampler* chain, llama_token_data_array* cur_p) {
583 if (chain && cur_p) {
584 llama_sampler_apply(chain, cur_p);
585 }
586}
587
596inline void accept(llama_sampler* chain, llama_token token) {
597 if (chain) {
598 llama_sampler_accept(chain, token);
599 }
600}
601
602} // namespace sampler
603} // namespace lloyal
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:48
C++20 concept: Any type with sampling parameter fields.
Definition sampler.hpp:74
Zero-copy logits access with clear lifetime semantics.
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
Definition sampler.hpp:51
float * get(llama_context *ctx, int32_t index=-1)
Definition logits.hpp:79
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
Definition sampler.hpp:582
void reseed_chain(llama_sampler *chain, uint32_t new_seed)
Reseed the dist sampler in a chain.
Definition sampler.hpp:543
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
Definition sampler.hpp:596
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
Definition sampler.hpp:527
llama_token sample_with_params(llama_context *ctx, const llama_vocab *vocab, const P &params, llama_sampler *grammarSampler=nullptr)
Sample with configurable parameters (template accepts any SamplingParams type)
Definition sampler.hpp:180
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
Definition sampler.hpp:466
void free_chain(llama_sampler *chain)
Free a sampler chain.
Definition sampler.hpp:568
llama_token greedy(llama_context *ctx, const llama_vocab *vocab)
Greedy sampling: Select token with highest probability.
Definition sampler.hpp:111
const llama_vocab * get_vocab(const llama_model *model)
Get vocabulary from model.
Boundary Tracker Stub for OSS liblloyal.
Type trait to detect std::optional<T>
Definition sampler.hpp:38
Text Tokenization Operations.