#include <llama/llama.h>

#include <algorithm>
#include <cstdint>
#include <span>
#include <stdexcept>
#include <vector>
124[[nodiscard]]
inline int many(llama_context *ctx,
const llama_token *tokens,
125 int32_t n_tokens, int32_t n_past, int32_t n_batch,
126 llama_seq_id seq_id = 0) {
128 "[decode::many] Processing %d tokens at position %d", n_tokens,
133 throw std::runtime_error(
"decode::many - NULL context");
136 if (!tokens || n_tokens <= 0) {
138 throw std::runtime_error(
"decode::many - Invalid token array");
142 throw std::runtime_error(
"decode::many - n_batch must be positive");
146 struct ThreadLocalBatch {
148 int32_t capacity = 0;
150 void ensure(int32_t n) {
151 if (n <= capacity)
return;
152 if (capacity > 0) llama_batch_free(batch);
153 batch = llama_batch_init(n, 0, 1);
157 ~ThreadLocalBatch() {
158 if (capacity > 0) llama_batch_free(batch);
161 thread_local ThreadLocalBatch tl;
163 llama_batch& batch = tl.batch;
166 int32_t processed = 0;
167 while (processed < n_tokens) {
168 const int32_t n_eval = std::min(n_tokens - processed, n_batch);
171 common_batch_clear(batch);
174 const bool is_last_chunk = (processed + n_eval >= n_tokens);
175 for (int32_t i = 0; i < n_eval; ++i) {
176 const int32_t pos = n_past + i;
177 const bool want_logits = is_last_chunk && (i == n_eval - 1);
182 common_batch_add(batch, tokens[processed + i], pos, {seq_id}, want_logits);
186 const int rc = llama_decode(ctx, batch);
189 "[decode::many] ERROR: llama_decode failed at position %d (rc=%d)",
198 processed, n_tokens);
206[[nodiscard]]
inline int many(llama_context *ctx,
207 const std::vector<llama_token> &tokens,
208 int32_t n_past, int32_t n_batch,
209 llama_seq_id seq_id = 0) {
210 return many(ctx, tokens.data(),
static_cast<int32_t
>(tokens.size()), n_past,
238[[nodiscard]]
inline int one(llama_context *ctx, llama_token tok, llama_pos pos,
239 llama_seq_id seq_id = 0,
bool want_logits =
true) {
241 throw std::runtime_error(
"decode::one - NULL context");
244 struct ThreadLocalBatch {
245 llama_batch batch = llama_batch_init(1, 0, 1);
246 ~ThreadLocalBatch() { llama_batch_free(batch); }
248 thread_local ThreadLocalBatch tl;
250 common_batch_clear(tl.batch);
251 common_batch_add(tl.batch, tok, pos, {seq_id}, want_logits);
253 return llama_decode(ctx, tl.batch);
311 batch.n_tokens = n_tokens;
313 batch.embd =
nullptr;
314 batch.pos =
pos_.data();
339[[nodiscard]]
inline int each(llama_context* ctx,
344 throw std::runtime_error(
"decode::each - NULL context");
347 throw std::runtime_error(
"decode::each - negative item count");
349 if (n == 0)
return 0;
353 for (int32_t i = 0; i < n; ++i) {
355 scratch.
pos_[i] = items[i].
pos;
362 llama_batch batch = scratch.
as_batch(n);
364 LLOYAL_LOG_DEBUG(
"[decode::each] Submitting %d tokens across %d sequences", n, n);
366 return llama_decode(ctx, batch);
370[[nodiscard]]
inline int each(llama_context* ctx,
371 const std::vector<EachItem>& items,
373 return each(ctx, items.data(),
static_cast<int32_t
>(items.size()), scratch);
395[[nodiscard]]
inline int scatter(llama_context* ctx,
400 throw std::runtime_error(
"decode::scatter - NULL context");
403 throw std::runtime_error(
"decode::scatter - negative item count");
407 for (int32_t i = 0; i < n; ++i) {
408 total +=
static_cast<int32_t
>(items[i].
tokens.size());
410 if (total == 0)
return 0;
415 for (int32_t i = 0; i < n; ++i) {
416 const auto& item = items[i];
417 const llama_pos base_pos = item.
start_pos;
418 const int32_t item_n =
static_cast<int32_t
>(item.tokens.size());
420 for (int32_t j = 0; j < item_n; ++j) {
421 scratch.
tokens_[cursor] = item.tokens[j];
422 scratch.
pos_[cursor] = base_pos + j;
427 const bool want_logits =
428 item.output_logits ? (j == item_n - 1) :
false;
429 scratch.
logits_[cursor] = want_logits ? int8_t{1} : int8_t{0};
435 llama_batch batch = scratch.
as_batch(total);
437 LLOYAL_LOG_DEBUG(
"[decode::scatter] Submitting %d total tokens across %d sequences", total, n);
439 return llama_decode(ctx, batch);
443[[nodiscard]]
inline int scatter(llama_context* ctx,
444 const std::vector<ScatterItem>& items,
446 return scatter(ctx, items.data(),
static_cast<int32_t
>(items.size()), scratch);
481 const std::span<const llama_token>* items,
485 std::vector<PackedChunk> chunks;
486 int32_t chunk_total = 0;
488 for (int32_t i = 0; i < n; ++i) {
489 int32_t tc =
static_cast<int32_t
>(items[i].size());
490 if (tc == 0)
continue;
493 chunks.push_back({{i},
true});
497 if (chunks.empty() || chunks.back().oversized ||
498 chunk_total + tc > n_batch) {
499 chunks.push_back({{i},
false});
502 chunks.back().indices.push_back(i);
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Input item for decode::each — one token for one sequence.
llama_token token
Token to decode.
bool output_logits
Whether to compute logits after this token.
llama_seq_id seq_id
Target sequence ID.
llama_pos pos
KV cache position for this token.
A chunk of item indices produced by bin_pack()
bool oversized
True → single item exceeding n_batch.
std::vector< int32_t > indices
Indices into the original items array.
Input item for decode::scatter — multiple tokens for one sequence.
llama_pos start_pos
KV cache position for first token.
bool output_logits
When true, compute logits for last token in this run.
std::span< const llama_token > tokens
Token array (non-owning view)
llama_seq_id seq_id
Target sequence ID.
Reusable scratch buffers for multi-sequence batch construction.
std::vector< llama_seq_id > seq_id_single_
std::vector< llama_token > tokens_
std::vector< int32_t > n_seq_id_
std::vector< int8_t > logits_
std::vector< llama_seq_id * > seq_id_ptrs_
llama_batch as_batch(int32_t n_tokens)
ABI-sensitive: writes llama_batch fields directly (no common_batch_* wrapper exists for external-buffer batches).
std::vector< llama_pos > pos_