liblloyal 1.0.0
Composable primitives for llama.cpp inference
decoder.hpp
#pragma once

// SPDX-License-Identifier: Apache-2.0
// Copyright 2026 Lloyal Labs

#include "common.hpp"
#include "helpers.hpp"
#include <algorithm>
#include <cstdint>
#include <llama/llama.h>
#include <stdexcept>
#include <vector>

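/**
 * @def LLOYAL_STACK_BATCH
 * @brief Selects the batch construction strategy used by decode_one().
 *
 * When 1 (the default), decode_one() builds a llama_batch on the stack with
 * no heap allocation; when 0, it reuses a thread_local batch created by
 * llama_batch_init(). Because of the #ifndef guard below, the setting can be
 * overridden before including this header or on the compiler command line
 * (e.g. -DLLOYAL_STACK_BATCH=0) to opt into the ABI-safe path.
 */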
#ifndef LLOYAL_STACK_BATCH
#define LLOYAL_STACK_BATCH 1
#endif

namespace lloyal::detail {
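/**
 * @brief RAII guard for automatic batch cleanup.
 *
 * Ensures llama_batch_free() is called even if exceptions occur. A minimal
 * usage sketch, assuming a batch created with llama_batch_init():
 *
 * @code
 * llama_batch batch = llama_batch_init(512, 0, 1);
 * BatchGuard guard(batch); // batch is freed when guard leaves scope
 * @endcode
 */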
struct BatchGuard {
  llama_batch &batch;
  explicit BatchGuard(llama_batch &b) : batch(b) {}
  ~BatchGuard() { llama_batch_free(batch); }
};

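/**
 * @brief Add tokens to batch with position info.
 *
 * Clears @p batch, then appends @p n_eval tokens starting at
 * tokens[start_idx], assigning consecutive positions from @p n_past and
 * requesting logits only for the final token of the chunk.
 */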
inline void add_tokens_to_batch(llama_batch &batch, const llama_token *tokens,
                                int32_t start_idx, int32_t n_eval,
                                int32_t n_past, int32_t capacity,
                                llama_seq_id seq_id = 0) {
  // Clear batch using helpers.hpp function
  lloyal::batch_clear(batch);

  // Add tokens one by one, mark logits=true on LAST token only
  for (int32_t i = 0; i < n_eval; ++i) {
    const int32_t pos = n_past + i;
    const bool want_logits = (i == n_eval - 1);

    // Add token to specified sequence
    lloyal::batch_add(batch, tokens[start_idx + i], pos, {seq_id}, want_logits,
                      capacity);
  }
}
} // namespace lloyal::detail

namespace lloyal::decoder {

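/**
 * @brief Process tokens through model to update KV cache.
 *
 * Splits the token array into chunks of at most @p n_batch tokens and decodes
 * each chunk at consecutive positions starting from @p n_past. A minimal
 * prefill sketch; `tokenize` here is a hypothetical caller-supplied helper,
 * not part of this library:
 *
 * @code
 * std::vector<llama_token> prompt_tokens = tokenize(prompt);
 * lloyal::decoder::decode_tokens(ctx, prompt_tokens.data(),
 *                                static_cast<int32_t>(prompt_tokens.size()),
 *                                0, 512); // n_past = 0, n_batch = 512
 * // KV cache now covers positions [0, prompt_tokens.size())
 * @endcode
 *
 * @throws std::runtime_error on NULL context, invalid input, or decode
 *         failure.
 */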
inline void decode_tokens(llama_context *ctx, const llama_token *tokens,
                          int32_t n_tokens, int32_t n_past, int32_t n_batch,
                          llama_seq_id seq_id = 0) {
  LLOYAL_LOG_DEBUG(
      "[decoder::decode_tokens] Processing %d tokens at position %d", n_tokens,
      n_past);

  if (!ctx) {
    LLOYAL_LOG_DEBUG("[decoder::decode_tokens] ERROR: NULL context");
    throw std::runtime_error("decoder::decode_tokens - NULL context");
  }

  if (!tokens || n_tokens <= 0) {
    LLOYAL_LOG_DEBUG("[decoder::decode_tokens] ERROR: Invalid token array");
    throw std::runtime_error("decoder::decode_tokens - Invalid token array");
  }

  // Initialize batch with RAII cleanup
  // Single-sequence batch (n_seq_max = 1)
  llama_batch batch = llama_batch_init(n_batch, 0, 1);
  detail::BatchGuard batch_guard(batch);

  // Process tokens in chunks
  int32_t processed = 0;
  while (processed < n_tokens) {
    const int32_t n_eval = std::min(n_tokens - processed, n_batch);

    // Add chunk to batch
    detail::add_tokens_to_batch(batch, tokens, processed, n_eval, n_past,
                                n_batch, seq_id);

    // Decode chunk (updates KV cache)
    if (llama_decode(ctx, batch) != 0) {
      LLOYAL_LOG_DEBUG(
          "[decoder::decode_tokens] ERROR: llama_decode failed at position %d",
          n_past);
      throw std::runtime_error("decoder::decode_tokens - llama_decode failed");
    }

    n_past += n_eval;
    processed += n_eval;

    LLOYAL_LOG_DEBUG("[decoder::decode_tokens] Processed %d/%d tokens",
                     processed, n_tokens);
  }

  LLOYAL_LOG_DEBUG("[decoder::decode_tokens] Decode complete");
}

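/**
 * @brief Convenience overload of decode_tokens() for std::vector input.
 */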
inline void decode_tokens(llama_context *ctx,
                          const std::vector<llama_token> &tokens,
                          int32_t n_past, int32_t n_batch,
                          llama_seq_id seq_id = 0) {
  decode_tokens(ctx, tokens.data(), static_cast<int32_t>(tokens.size()), n_past,
                n_batch, seq_id);
}

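/**
 * @brief Decode a single token with zero heap allocation
 *        (when LLOYAL_STACK_BATCH=1).
 *
 * A minimal generation-step sketch; `sampled` and `n_prompt` are hypothetical
 * values produced by the caller's sampler and prefill step:
 *
 * @code
 * llama_pos pos = n_prompt; // next position after the prefilled prompt
 * lloyal::decoder::decode_one(ctx, sampled, pos);
 * // logits for the next sampling step are available via
 * // llama_get_logits_ith(ctx, 0)
 * @endcode
 */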
inline void decode_one(llama_context *ctx, llama_token tok, llama_pos pos,
                       llama_seq_id seq_id = 0, bool want_logits = true) {
  if (!ctx) {
    throw std::runtime_error("decoder::decode_one - NULL context");
  }

#if LLOYAL_STACK_BATCH
  // Fast path: zero-allocation stack-constructed batch
  // WARNING: ABI-fragile - breaks if llama_batch struct layout changes
  llama_token tok_arr[1] = {tok};
  llama_pos pos_arr[1] = {pos};
  int32_t n_seq_id_arr[1] = {1};
  llama_seq_id seq_arr[1] = {seq_id};
  llama_seq_id *seq_ptrs[1] = {seq_arr};
  int8_t logits_arr[1] = {static_cast<int8_t>(want_logits)};

  llama_batch batch{};
  batch.n_tokens = 1;
  batch.token = tok_arr;
  batch.embd = nullptr;
  batch.pos = pos_arr;
  batch.n_seq_id = n_seq_id_arr;
  batch.seq_id = seq_ptrs;
  batch.logits = logits_arr;
#else
  // Safe path: thread_local batch via llama.cpp's own initializer
  // Handles any new fields with defaults, survives ABI changes
  thread_local llama_batch batch = llama_batch_init(1, 0, 1);

  batch.n_tokens = 1;
  batch.token[0] = tok;
  batch.pos[0] = pos;
  batch.n_seq_id[0] = 1;
  batch.seq_id[0][0] = seq_id;
  batch.logits[0] = static_cast<int8_t>(want_logits);
#endif

  if (llama_decode(ctx, batch) != 0) {
    throw std::runtime_error("decoder::decode_one - llama_decode failed");
  }
}

} // namespace lloyal::decoder