#include <algorithm>
#include <cstdint>
#include <span>
#include <stdexcept>
#include <vector>

#include <llama/llama.h>
125[[nodiscard]]
inline int many(llama_context *ctx,
const llama_token *tokens,
126 int32_t n_tokens, int32_t n_past, int32_t n_batch,
127 llama_seq_id seq_id = 0) {
129 "[decode::many] Processing %d tokens at position %d", n_tokens,
134 throw std::runtime_error(
"decode::many - NULL context");
137 if (!tokens || n_tokens <= 0) {
139 throw std::runtime_error(
"decode::many - Invalid token array");
143 throw std::runtime_error(
"decode::many - n_batch must be positive");
147 struct ThreadLocalBatch {
149 int32_t capacity = 0;
151 void ensure(int32_t n) {
152 if (n <= capacity)
return;
153 if (capacity > 0) llama_batch_free(batch);
154 batch = llama_batch_init(n, 0, 1);
158 ~ThreadLocalBatch() {
159 if (capacity > 0) llama_batch_free(batch);
162 thread_local ThreadLocalBatch tl;
164 llama_batch& batch = tl.batch;
167 int32_t processed = 0;
168 while (processed < n_tokens) {
169 const int32_t n_eval = std::min(n_tokens - processed, n_batch);
172 common_batch_clear(batch);
175 const bool is_last_chunk = (processed + n_eval >= n_tokens);
176 for (int32_t i = 0; i < n_eval; ++i) {
177 const int32_t pos = n_past + i;
178 const bool want_logits = is_last_chunk && (i == n_eval - 1);
183 common_batch_add(batch, tokens[processed + i], pos, {seq_id}, want_logits);
187 const int rc = llama_decode(ctx, batch);
190 "[decode::many] ERROR: llama_decode failed at position %d (rc=%d)",
199 processed, n_tokens);
207[[nodiscard]]
inline int many(llama_context *ctx,
208 const std::vector<llama_token> &tokens,
209 int32_t n_past, int32_t n_batch,
210 llama_seq_id seq_id = 0) {
211 return many(ctx, tokens.data(),
static_cast<int32_t
>(tokens.size()), n_past,
239[[nodiscard]]
inline int one(llama_context *ctx, llama_token tok, llama_pos pos,
240 llama_seq_id seq_id = 0,
bool want_logits =
true) {
242 throw std::runtime_error(
"decode::one - NULL context");
245 struct ThreadLocalBatch {
246 llama_batch batch = llama_batch_init(1, 0, 1);
247 ~ThreadLocalBatch() { llama_batch_free(batch); }
249 thread_local ThreadLocalBatch tl;
251 common_batch_clear(tl.batch);
252 common_batch_add(tl.batch, tok, pos, {seq_id}, want_logits);
254 return llama_decode(ctx, tl.batch);
312 batch.n_tokens = n_tokens;
314 batch.embd =
nullptr;
315 batch.pos =
pos_.data();
340[[nodiscard]]
inline int each(llama_context* ctx,
345 throw std::runtime_error(
"decode::each - NULL context");
348 throw std::runtime_error(
"decode::each - negative item count");
350 if (n == 0)
return 0;
354 for (int32_t i = 0; i < n; ++i) {
356 scratch.
pos_[i] = items[i].
pos;
363 llama_batch batch = scratch.
as_batch(n);
365 LLOYAL_LOG_DEBUG(
"[decode::each] Submitting %d tokens across %d sequences", n, n);
367 return llama_decode(ctx, batch);
371[[nodiscard]]
inline int each(llama_context* ctx,
372 const std::vector<EachItem>& items,
374 return each(ctx, items.data(),
static_cast<int32_t
>(items.size()), scratch);
396[[nodiscard]]
inline int scatter(llama_context* ctx,
401 throw std::runtime_error(
"decode::scatter - NULL context");
404 throw std::runtime_error(
"decode::scatter - negative item count");
408 for (int32_t i = 0; i < n; ++i) {
409 total +=
static_cast<int32_t
>(items[i].
tokens.size());
411 if (total == 0)
return 0;
416 for (int32_t i = 0; i < n; ++i) {
417 const auto& item = items[i];
418 const llama_pos base_pos = item.
start_pos;
419 const int32_t item_n =
static_cast<int32_t
>(item.tokens.size());
421 for (int32_t j = 0; j < item_n; ++j) {
422 scratch.
tokens_[cursor] = item.tokens[j];
423 scratch.
pos_[cursor] = base_pos + j;
428 const bool want_logits =
429 item.output_logits ? (j == item_n - 1) :
false;
430 scratch.
logits_[cursor] = want_logits ? int8_t{1} : int8_t{0};
436 llama_batch batch = scratch.
as_batch(total);
438 LLOYAL_LOG_DEBUG(
"[decode::scatter] Submitting %d total tokens across %d sequences", total, n);
440 return llama_decode(ctx, batch);
444[[nodiscard]]
inline int scatter(llama_context* ctx,
445 const std::vector<ScatterItem>& items,
447 return scatter(ctx, items.data(),
static_cast<int32_t
>(items.size()), scratch);
482 const std::span<const llama_token>* items,
486 std::vector<PackedChunk> chunks;
487 int32_t chunk_total = 0;
489 for (int32_t i = 0; i < n; ++i) {
490 int32_t tc =
static_cast<int32_t
>(items[i].size());
491 if (tc == 0)
continue;
494 chunks.push_back({{i},
true});
498 if (chunks.empty() || chunks.back().oversized ||
499 chunk_total + tc > n_batch) {
500 chunks.push_back({{i},
false});
503 chunks.back().indices.push_back(i);
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Input item for decode::each — one token for one sequence.
llama_token token
Token to decode.
bool output_logits
Whether to compute logits after this token.
llama_seq_id seq_id
Target sequence ID.
llama_pos pos
KV cache position for this token.
A chunk of item indices produced by bin_pack()
bool oversized
True → single item exceeding n_batch.
std::vector< int32_t > indices
Indices into the original items array.
Input item for decode::scatter — multiple tokens for one sequence.
llama_pos start_pos
KV cache position for first token.
bool output_logits
When true, compute logits for last token in this run.
std::span< const llama_token > tokens
Token array (non-owning view)
llama_seq_id seq_id
Target sequence ID.
Reusable scratch buffers for multi-sequence batch construction.
std::vector< llama_seq_id > seq_id_single_
std::vector< llama_token > tokens_
std::vector< int32_t > n_seq_id_
std::vector< int8_t > logits_
std::vector< llama_seq_id * > seq_id_ptrs_
llama_batch as_batch(int32_t n_tokens)
ABI-sensitive: writes llama_batch fields directly (no common_batch_* wrapper exists for externally-owned buffers).
std::vector< llama_pos > pos_