liblloyal 1.0.0
Composable primitives for llama.cpp inference
Loading...
Searching...
No Matches
sampler.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
#include "common.hpp"
#include "logits.hpp"
#include "tokenizer.hpp"

#include <llama/llama.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <vector>
16
namespace lloyal::detail {

// ===== OPTIONAL VALUE EXTRACTION =====

/// Type trait: detects whether T is a specialization of std::optional.
/// Primary template: anything that is not std::optional<...> maps to false.
template <class T> struct is_optional : std::false_type {};

/// Partial specialization: std::optional<T> maps to true.
template <class T> struct is_optional<std::optional<T>> : std::true_type {};

/// Normalize a parameter that may be either a plain value or a
/// std::optional<T> into a concrete T.
///
/// - Plain value: converted to T via static_cast.
/// - std::optional: the contained value, or @p def when empty.
///
/// Used by the samplers below so callers can supply parameter structs with
/// either required or optional fields.
template <class X, class T> constexpr T as_value(const X &x, T def) {
  if constexpr (!is_optional<X>::value) {
    return static_cast<T>(x);
  } else {
    return x.value_or(def);
  }
}

} // namespace lloyal::detail
59
60namespace lloyal {
61
62// ===== SAMPLING PARAMS CONCEPT =====
63
/**
 * @brief C++20 concept: any type exposing the core llama.cpp sampling fields.
 *
 * Each listed field may be a plain arithmetic type or a std::optional<T>;
 * detail::as_value() normalizes either form (with a default) at the point
 * of use. Only field *presence* is required here — no type constraints are
 * imposed, so project-specific parameter structs satisfy the concept
 * without inheriting from a common base.
 */
template <class P>
concept SamplingParamsLike = requires(const P &p) {
  p.temperature;
  p.top_k;
  p.top_p;
  p.typical_p;
  p.min_p;
  p.penalty_repeat;
  p.penalty_freq;
  p.penalty_present;
  p.penalty_last_n;
  p.seed;
  // Additional fields for future extensions:
  // p.mirostat, p.mirostat_tau, p.mirostat_eta
  // p.dry_multiplier, p.dry_base, p.dry_allowed_length, p.dry_penalty_last_n
  // p.xtc_probability, p.xtc_threshold
  // p.top_n_sigma
};
90
91namespace sampler {
92
93// ===== GREEDY SAMPLING =====
94
110inline llama_token greedy(llama_context *ctx, const llama_vocab *vocab) {
111 LLOYAL_LOG_DEBUG("[sampler::greedy] Sampling next token");
112
113 if (!ctx) {
114 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: NULL context");
115 throw std::runtime_error("sampler::greedy - NULL context");
116 }
117
118 if (!vocab) {
119 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: NULL vocabulary");
120 throw std::runtime_error("sampler::greedy - NULL vocabulary");
121 }
122
123 // Get last-step logits (index -1)
124 // Per llama.cpp maintainers: only works if logits=true was set for that step
125 // lloyal::logits::get() handles null checking and throws descriptive errors
126 const float *logits = lloyal::logits::get(ctx, -1);
127
128 // Get vocabulary size
129 const int n_vocab = llama_vocab_n_tokens(vocab);
130 if (n_vocab <= 0) {
131 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: Invalid vocabulary size: %d",
132 n_vocab);
133 throw std::runtime_error("sampler::greedy - Invalid vocabulary size");
134 }
135
136 // Argmax: Find token with highest probability
137 int best_id = 0;
138 float best_score = logits[0];
139 for (int i = 1; i < n_vocab; ++i) {
140 if (logits[i] > best_score) {
141 best_score = logits[i];
142 best_id = i;
143 }
144 }
145
146 llama_token result = static_cast<llama_token>(best_id);
147 LLOYAL_LOG_DEBUG("[sampler::greedy] Sampled token: %d (score: %.4f)", result,
148 best_score);
149
150 return result;
151}
152
153// ===== PARAMETERIZED SAMPLING =====
154
178template <SamplingParamsLike P>
179inline llama_token sample_with_params(llama_context *ctx,
180 const llama_vocab *vocab, const P &params,
181 llama_sampler *grammarSampler = nullptr) {
182 using detail::as_value;
183
184 // Extract parameters with defaults (handles both T and std::optional<T>)
185 uint32_t seed =
186 as_value(params.seed, static_cast<uint32_t>(std::time(nullptr)));
187 int32_t top_k = as_value(params.top_k, static_cast<int32_t>(40));
188 float top_p = as_value(params.top_p, 0.95f);
189 float min_p = as_value(params.min_p, 0.05f);
190 float typical_p = as_value(params.typical_p, 1.0f);
191 float temperature = as_value(params.temperature, 0.8f);
192 int32_t penalty_last_n =
193 as_value(params.penalty_last_n, static_cast<int32_t>(64));
194 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
195 float penalty_freq = as_value(params.penalty_freq, 0.0f);
196 float penalty_present = as_value(params.penalty_present, 0.0f);
197
198 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Building sampler");
199 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] temperature=%.2f, "
200 "top_k=%d, top_p=%.2f, min_p=%.2f",
201 temperature, static_cast<int>(top_k), top_p, min_p);
202
203 if (!ctx) {
204 throw std::runtime_error("sampler::sample_with_params - NULL context");
205 }
206 if (!vocab) {
207 throw std::runtime_error("sampler::sample_with_params - NULL vocabulary");
208 }
209
210 // ROUTING DECISION: Grammar present → use grammar-aware sampling
211 // No grammar → use lightweight chain approach
212 if (grammarSampler) {
213 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Grammar sampler provided, "
214 "using grammar-constrained sampling");
215
216 // Get logits and build token data array
217 // lloyal::logits::get() handles null checking and throws descriptive errors
218 const float *logits = lloyal::logits::get(ctx, -1);
219
220 const int n_vocab = llama_vocab_n_tokens(vocab);
221 if (n_vocab <= 0) {
222 throw std::runtime_error(
223 "sampler::sample_with_params - Invalid vocabulary size");
224 }
225
226 // Build candidate array from logits
227 std::vector<llama_token_data> candidates(n_vocab);
228 for (int i = 0; i < n_vocab; i++) {
229 candidates[i] =
230 llama_token_data{static_cast<llama_token>(i), logits[i], 0.0f};
231 }
232
233 llama_token_data_array cur_p = {
234 candidates.data(), static_cast<size_t>(n_vocab),
235 -1, // selected (will be set by samplers)
236 false // sorted
237 };
238
239 // Build sampler chain (WITHOUT grammar - grammar applied separately)
240 llama_sampler_chain_params chain_params =
241 llama_sampler_chain_default_params();
242 chain_params.no_perf = true;
243 auto *chain = llama_sampler_chain_init(chain_params);
244
245 // Add samplers in order (penalties → top-k → typical-p → top-p → min-p →
246 // temperature → dist)
247 if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
248 penalty_present != 0.0f) {
249 llama_sampler_chain_add(
250 chain, llama_sampler_init_penalties(penalty_last_n, penalty_repeat,
251 penalty_freq, penalty_present));
252 }
253 if (top_k > 0) {
254 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
255 }
256 if (typical_p < 1.0f) {
257 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
258 }
259 if (top_p < 1.0f) {
260 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
261 }
262 if (min_p > 0.0f) {
263 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
264 }
265 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
266 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
267
268 // Apply grammar constraint FIRST (uses persistent parser state from
269 // context)
270 llama_sampler_apply(grammarSampler, &cur_p);
271
272 // Then apply chain
273 llama_sampler_apply(chain, &cur_p);
274
275 // Get selected token
276 if (cur_p.selected == -1) {
277 llama_sampler_free(chain);
278 throw std::runtime_error(
279 "No selected token during sampling - check sampling configuration");
280 }
281 llama_token result = cur_p.data[cur_p.selected].id;
282
283 // Clean up chain (grammar sampler is managed by context, not freed here)
284 llama_sampler_free(chain);
285
286 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Grammar-sampled token: %d",
287 result);
288 return result;
289 }
290
291 // NO GRAMMAR: Use lightweight chain approach
292 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] No grammar, using "
293 "lightweight chain approach");
294
295 // Get logits
296 // lloyal::logits::get() handles null checking and throws descriptive errors
297 const float *logits = lloyal::logits::get(ctx, -1);
298
299 const int n_vocab = llama_vocab_n_tokens(vocab);
300 if (n_vocab <= 0) {
301 throw std::runtime_error(
302 "sampler::sample_with_params - Invalid vocabulary size");
303 }
304
305 // Create llama.cpp sampler chain
306 llama_sampler_chain_params chain_params =
307 llama_sampler_chain_default_params();
308 chain_params.no_perf = true;
309
310 auto *sampler_chain = llama_sampler_chain_init(chain_params);
311 if (!sampler_chain) {
312 throw std::runtime_error(
313 "sampler::sample_with_params - Failed to create sampler chain");
314 }
315
316 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Sampler chain created, "
317 "adding samplers...");
318
319 // 1. Repetition penalties (if enabled)
320 if (penalty_repeat != 1.0f || penalty_freq != 0.0f ||
321 penalty_present != 0.0f) {
322 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + penalties "
323 "(repeat=%.2f, freq=%.2f, present=%.2f, last_n=%d)",
324 penalty_repeat, penalty_freq, penalty_present,
325 penalty_last_n);
326 llama_sampler_chain_add(sampler_chain, llama_sampler_init_penalties(
327 penalty_last_n, penalty_repeat,
328 penalty_freq, penalty_present));
329 }
330
331 // 2. Top-K sampling (if enabled)
332 if (top_k > 0) {
333 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + top_k (%d)",
334 static_cast<int>(top_k));
335 llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_k(top_k));
336 }
337
338 // 3. Typical-P sampling (if enabled)
339 if (typical_p < 1.0f) {
340 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + typical_p (%.2f)",
341 typical_p);
342 llama_sampler_chain_add(sampler_chain,
343 llama_sampler_init_typical(typical_p, 1));
344 }
345
346 // 4. Top-P sampling (if enabled)
347 if (top_p < 1.0f) {
348 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + top_p (%.2f)", top_p);
349 llama_sampler_chain_add(sampler_chain, llama_sampler_init_top_p(top_p, 1));
350 }
351
352 // 5. Min-P sampling (if enabled)
353 if (min_p > 0.0f) {
354 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + min_p (%.2f)", min_p);
355 llama_sampler_chain_add(sampler_chain, llama_sampler_init_min_p(min_p, 1));
356 }
357
358 // 6. Temperature scaling
359 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + temperature (%.2f)",
360 temperature);
361 llama_sampler_chain_add(sampler_chain, llama_sampler_init_temp(temperature));
362
363 // 7. Final distribution sampler
364 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] + dist (seed=%u)", seed);
365 llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(seed));
366
367 // Sample from the chain
368 llama_token result = llama_sampler_sample(sampler_chain, ctx, -1);
369
370 // Free the sampler chain
371 llama_sampler_free(sampler_chain);
372
373 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] Sampled token: %d "
374 "(temp=%.2f, top_k=%d, top_p=%.2f, min_p=%.2f)",
375 result, temperature, static_cast<int>(top_k), top_p, min_p);
376
377 return result;
378}
379
380// ===== MODEL-ACCEPTING CONVENIENCE OVERLOADS =====
381//
382// These overloads accept llama_model* and handle vocab extraction internally.
383// They delegate to the vocab-accepting primitives above.
384//
385// Benefits:
386// - Eliminate boilerplate (vocab extraction) in calling code
387// - Reduce code duplication across projects
388// - Backwards compatible - existing code unchanged
389
400inline llama_token greedy(llama_context *ctx, const llama_model *model) {
401 if (!model) {
402 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: model is null");
403 throw std::runtime_error("sampler::greedy - NULL model");
404 }
405
406 const llama_vocab *vocab = lloyal::tokenizer::get_vocab(model);
407 if (!vocab) {
408 LLOYAL_LOG_DEBUG("[sampler::greedy] ERROR: get_vocab returned null");
409 throw std::runtime_error(
410 "sampler::greedy - Failed to get vocab from model");
411 }
412
413 return greedy(ctx, vocab);
414}
415
428template <SamplingParamsLike P>
429inline llama_token sample_with_params(llama_context *ctx,
430 const llama_model *model, const P &params,
431 llama_sampler *grammarSampler = nullptr) {
432 if (!model) {
433 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] ERROR: model is null");
434 throw std::runtime_error("sampler::sample_with_params - NULL model");
435 }
436
437 const llama_vocab *vocab = lloyal::tokenizer::get_vocab(model);
438 if (!vocab) {
439 LLOYAL_LOG_DEBUG("[sampler::sample_with_params] ERROR: get_vocab "
440 "returned null");
441 throw std::runtime_error(
442 "sampler::sample_with_params - Failed to get vocab from model");
443 }
444
445 return sample_with_params(ctx, vocab, params, grammarSampler);
446}
447
448// ===== PERSISTENT SAMPLER CHAIN PRIMITIVES =====
449//
450// These functions manage persistent sampler chains for use cases like
451// branch-based MCTS where the same chain is reused across multiple samples.
452// Unlike sample_with_params() which creates/destroys a chain per call,
453// these allow chain reuse for better performance.
454
464template <SamplingParamsLike P>
465inline llama_sampler* create_chain(const P& params) {
466 using detail::as_value;
467
468 // Extract parameters with defaults
469 int32_t penalty_last_n = as_value(params.penalty_last_n, static_cast<int32_t>(64));
470 float penalty_repeat = as_value(params.penalty_repeat, 1.0f);
471 float penalty_freq = as_value(params.penalty_freq, 0.0f);
472 float penalty_present = as_value(params.penalty_present, 0.0f);
473 int32_t top_k = as_value(params.top_k, static_cast<int32_t>(40));
474 float top_p = as_value(params.top_p, 0.95f);
475 float min_p = as_value(params.min_p, 0.05f);
476 float typical_p = as_value(params.typical_p, 1.0f);
477 float temperature = as_value(params.temperature, 0.8f);
478 uint32_t seed = as_value(params.seed, static_cast<uint32_t>(std::time(nullptr)));
479
480 llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
481 chain_params.no_perf = true;
482 llama_sampler* chain = llama_sampler_chain_init(chain_params);
483
484 if (!chain) {
485 throw std::runtime_error("sampler::create_chain - Failed to create sampler chain");
486 }
487
488 // Add samplers in standard order: penalties → filters → temperature/greedy → dist
489 if (penalty_repeat != 1.0f || penalty_freq != 0.0f || penalty_present != 0.0f) {
490 llama_sampler_chain_add(chain,
491 llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present));
492 }
493 if (top_k > 0) {
494 llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
495 }
496 if (typical_p < 1.0f) {
497 llama_sampler_chain_add(chain, llama_sampler_init_typical(typical_p, 1));
498 }
499 if (top_p < 1.0f) {
500 llama_sampler_chain_add(chain, llama_sampler_init_top_p(top_p, 1));
501 }
502 if (min_p > 0.0f) {
503 llama_sampler_chain_add(chain, llama_sampler_init_min_p(min_p, 1));
504 }
505
506 // Temperature handling: temp <= 0 means greedy (deterministic argmax)
507 if (temperature <= 0.0f) {
508 llama_sampler_chain_add(chain, llama_sampler_init_greedy());
509 } else {
510 llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature));
511 llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
512 }
513
514 return chain;
515}
516
526inline llama_sampler* clone_chain(llama_sampler* chain) {
527 if (!chain) {
528 return nullptr;
529 }
530 return llama_sampler_clone(chain);
531}
532
542inline void reseed_chain(llama_sampler* chain, uint32_t new_seed) {
543 if (!chain) {
544 return;
545 }
546
547 int n = llama_sampler_chain_n(chain);
548 if (n == 0) {
549 return;
550 }
551
552 // Remove last sampler (dist by convention)
553 llama_sampler* old_dist = llama_sampler_chain_remove(chain, n - 1);
554 if (old_dist) {
555 llama_sampler_free(old_dist);
556 }
557
558 // Add new dist with new seed
559 llama_sampler_chain_add(chain, llama_sampler_init_dist(new_seed));
560}
561
567inline void free_chain(llama_sampler* chain) {
568 if (chain) {
569 llama_sampler_free(chain);
570 }
571}
572
581inline void apply(llama_sampler* chain, llama_token_data_array* cur_p) {
582 if (chain && cur_p) {
583 llama_sampler_apply(chain, cur_p);
584 }
585}
586
595inline void accept(llama_sampler* chain, llama_token token) {
596 if (chain) {
597 llama_sampler_accept(chain, token);
598 }
599}
600
601} // namespace sampler
602} // namespace lloyal
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:47
C++20 concept: Any type with sampling parameter fields.
Definition sampler.hpp:73
Zero-copy logits access with clear lifetime semantics.
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
Definition sampler.hpp:50
float * get(llama_context *ctx, int32_t step=-1)
Get raw logits pointer (zero-copy)
Definition logits.hpp:60
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
Definition sampler.hpp:581
void reseed_chain(llama_sampler *chain, uint32_t new_seed)
Reseed the dist sampler in a chain.
Definition sampler.hpp:542
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
Definition sampler.hpp:595
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
Definition sampler.hpp:526
llama_token sample_with_params(llama_context *ctx, const llama_vocab *vocab, const P &params, llama_sampler *grammarSampler=nullptr)
Sample with configurable parameters (template accepts any SamplingParams type)
Definition sampler.hpp:179
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
Definition sampler.hpp:465
void free_chain(llama_sampler *chain)
Free a sampler chain.
Definition sampler.hpp:567
llama_token greedy(llama_context *ctx, const llama_vocab *vocab)
Greedy sampling: Select token with highest probability.
Definition sampler.hpp:110
const llama_vocab * get_vocab(const llama_model *model)
Get vocabulary from model.
JSON Schema to Grammar Converter (Header-Only)
Definition minja.hpp:575
Type trait to detect std::optional<T>
Definition sampler.hpp:37
Text Tokenization Operations.