#include <llama/llama.h>

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Debug logging macro from "liblloyal - Common definitions and logging";
// fall back to a no-op if the common header has not defined it.
#ifndef LLOYAL_LOG_DEBUG
#define LLOYAL_LOG_DEBUG(...)
#endif
// When nonzero, decode_one() builds its single-token batch on the stack,
// avoiding any heap allocation on the per-token hot path.
#ifndef LLOYAL_STACK_BATCH
#define LLOYAL_STACK_BATCH 1
#endif
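// To opt out and use the reusable per-thread heap batch instead, define the
// flag as 0 before this header is included, e.g. (hypothetical build flag):
//
//   c++ -DLLOYAL_STACK_BATCH=0 ...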
// Add tokens to batch with position info. Requests logits only for the last
// token of the chunk.
inline void add_tokens_to_batch(llama_batch &batch, const llama_token *tokens,
                                int32_t start_idx, int32_t n_eval,
                                int32_t n_past, int32_t capacity,
                                llama_seq_id seq_id = 0) {
  if (n_eval > capacity) {
    throw std::runtime_error(
        "decoder::add_tokens_to_batch - batch capacity exceeded");
  }
  batch.n_tokens = n_eval;
  for (int32_t i = 0; i < n_eval; ++i) {
    const int32_t pos = n_past + i;
    const bool want_logits = (i == n_eval - 1);
    batch.token[i] = tokens[start_idx + i];
    batch.pos[i] = pos;
    batch.n_seq_id[i] = 1;
    batch.seq_id[i][0] = seq_id;
    batch.logits[i] = static_cast<int8_t>(want_logits);
  }
}
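// RAII guard for automatic batch cleanup: ensures llama_batch_free() is
// called even if exceptions occur. A minimal sketch reconstructed from this
// header's documented BatchGuard(llama_batch &b) interface; the original
// implementation may differ in detail.
struct BatchGuard {
  explicit BatchGuard(llama_batch &b) : batch_(b) {}
  ~BatchGuard() { llama_batch_free(batch_); }
  BatchGuard(const BatchGuard &) = delete;
  BatchGuard &operator=(const BatchGuard &) = delete;

private:
  llama_batch &batch_;
};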
// Process tokens through the model to update the KV cache. Tokens are split
// into chunks of at most n_batch and decoded sequentially.
inline void decode_tokens(llama_context *ctx, const llama_token *tokens,
                          int32_t n_tokens, int32_t n_past, int32_t n_batch,
                          llama_seq_id seq_id = 0) {
  LLOYAL_LOG_DEBUG(
      "[decoder::decode_tokens] Processing %d tokens at position %d", n_tokens,
      n_past);

  if (!ctx) {
    throw std::runtime_error("decoder::decode_tokens - NULL context");
  }
  if (!tokens || n_tokens <= 0) {
    throw std::runtime_error("decoder::decode_tokens - Invalid token array");
  }

  llama_batch batch = llama_batch_init(n_batch, 0, 1);
  BatchGuard guard(batch); // llama_batch_free on every exit path

  int32_t processed = 0;
  while (processed < n_tokens) {
    const int32_t n_eval = std::min(n_tokens - processed, n_batch);
    add_tokens_to_batch(batch, tokens, processed, n_eval, n_past + processed,
                        n_batch, seq_id);
    if (llama_decode(ctx, batch) != 0) {
      LLOYAL_LOG_DEBUG(
          "[decoder::decode_tokens] ERROR: llama_decode failed at position %d",
          n_past + processed);
      throw std::runtime_error("decoder::decode_tokens - llama_decode failed");
    }
    processed += n_eval;
    LLOYAL_LOG_DEBUG("[decoder::decode_tokens] Processed %d/%d tokens",
                     processed, n_tokens);
  }
}
// Convenience overload taking a std::vector of tokens.
inline void decode_tokens(llama_context *ctx,
                          const std::vector<llama_token> &tokens,
                          int32_t n_past, int32_t n_batch,
                          llama_seq_id seq_id = 0) {
  decode_tokens(ctx, tokens.data(), static_cast<int32_t>(tokens.size()),
                n_past, n_batch, seq_id);
}
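// Usage sketch (hypothetical caller code; assumes `ctx` is a valid
// llama_context and `prompt` holds already-tokenized input):
//
//   std::vector<llama_token> prompt = ...;            // tokenize elsewhere
//   decode_tokens(ctx, prompt, /*n_past=*/0, /*n_batch=*/512);
//   float *logits = llama_get_logits_ith(ctx, -1);    // last token's logits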
// Decode a single token with zero heap allocation (when LLOYAL_STACK_BATCH=1).
inline void decode_one(llama_context *ctx, llama_token tok, llama_pos pos,
                       llama_seq_id seq_id = 0, bool want_logits = true) {
  if (!ctx) {
    throw std::runtime_error("decoder::decode_one - NULL context");
  }

#if LLOYAL_STACK_BATCH
  // Build the one-token batch entirely on the stack.
  llama_token tok_arr[1] = {tok};
  llama_pos pos_arr[1] = {pos};
  int32_t n_seq_id_arr[1] = {1};
  llama_seq_id seq_arr[1] = {seq_id};
  llama_seq_id *seq_ptrs[1] = {seq_arr};
  int8_t logits_arr[1] = {static_cast<int8_t>(want_logits)};

  llama_batch batch{};
  batch.n_tokens = 1;
  batch.token = tok_arr;
  batch.embd = nullptr; // token-based decode, no embeddings
  batch.pos = pos_arr;
  batch.n_seq_id = n_seq_id_arr;
  batch.seq_id = seq_ptrs;
  batch.logits = logits_arr;
#else
  // Reuse a heap-allocated one-token batch per thread.
  thread_local llama_batch batch = llama_batch_init(1, 0, 1);
  batch.n_tokens = 1;
  batch.token[0] = tok;
  batch.pos[0] = pos;
  batch.n_seq_id[0] = 1;
  batch.seq_id[0][0] = seq_id;
  batch.logits[0] = static_cast<int8_t>(want_logits);
#endif

  if (llama_decode(ctx, batch) != 0) {
    throw std::runtime_error("decoder::decode_one - llama_decode failed");
  }
}
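// Usage sketch for an incremental generation loop (hypothetical; `my_sample`
// stands in for whatever sampler the caller uses):
//
//   llama_pos pos = n_prompt;
//   for (int i = 0; i < n_generate; ++i) {
//     decode_one(ctx, tok, pos++);
//     tok = my_sample(llama_get_logits_ith(ctx, -1));
//   }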
// ---------------------------------------------------------------------------
// Remaining batch helpers:
//
//   void batch_clear(llama_batch &batch)
//     Clear batch to empty state.
//
//   void batch_add(llama_batch &batch, llama_token id, int32_t pos,
//                  const std::vector<llama_seq_id> &seq_ids, bool logits,
//                  int32_t capacity = -1)
//     Add single token to batch with position and sequence info.
// ---------------------------------------------------------------------------
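// Minimal sketches of the two helpers summarized above, assuming semantics
// analogous to llama.cpp's common_batch_clear()/common_batch_add(); the
// original implementations may differ.
inline void batch_clear(llama_batch &batch) {
  batch.n_tokens = 0; // entries past n_tokens are ignored by llama_decode
}

inline void batch_add(llama_batch &batch, llama_token id, int32_t pos,
                      const std::vector<llama_seq_id> &seq_ids, bool logits,
                      int32_t capacity = -1) {
  const int32_t i = batch.n_tokens;
  if (capacity >= 0 && i >= capacity) {
    throw std::runtime_error("decoder::batch_add - batch capacity exceeded");
  }
  batch.token[i] = id;
  batch.pos[i] = pos;
  // Assumes the batch was initialized with n_seq_max >= seq_ids.size().
  batch.n_seq_id[i] = static_cast<int32_t>(seq_ids.size());
  for (size_t s = 0; s < seq_ids.size(); ++s) {
    batch.seq_id[i][s] = seq_ids[s];
  }
  batch.logits[i] = static_cast<int8_t>(logits);
  batch.n_tokens += 1;
}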