liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
decode.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6#include "common.hpp"
7#include <common.h> // llama.cpp common library: common_batch_clear, common_batch_add
8#include <algorithm>
9#include <cstdint>
10#include <llama/llama.h>
11#include <span>
12#include <stdexcept>
13#include <vector>
14
72namespace lloyal::decode {
73
124[[nodiscard]] inline int many(llama_context *ctx, const llama_token *tokens,
125 int32_t n_tokens, int32_t n_past, int32_t n_batch,
126 llama_seq_id seq_id = 0) {
128 "[decode::many] Processing %d tokens at position %d", n_tokens,
129 n_past);
130
131 if (!ctx) {
132 LLOYAL_LOG_DEBUG("[decode::many] ERROR: NULL context");
133 throw std::runtime_error("decode::many - NULL context");
134 }
135
136 if (!tokens || n_tokens <= 0) {
137 LLOYAL_LOG_DEBUG("[decode::many] ERROR: Invalid token array");
138 throw std::runtime_error("decode::many - Invalid token array");
139 }
140
141 if (n_batch <= 0) {
142 throw std::runtime_error("decode::many - n_batch must be positive");
143 }
144
145 // Thread-local batch avoids per-call allocation. Grows if needed, never shrinks.
146 struct ThreadLocalBatch {
147 llama_batch batch{};
148 int32_t capacity = 0;
149
150 void ensure(int32_t n) {
151 if (n <= capacity) return;
152 if (capacity > 0) llama_batch_free(batch);
153 batch = llama_batch_init(n, 0, 1);
154 capacity = n;
155 }
156
157 ~ThreadLocalBatch() {
158 if (capacity > 0) llama_batch_free(batch);
159 }
160 };
161 thread_local ThreadLocalBatch tl;
162 tl.ensure(n_batch);
163 llama_batch& batch = tl.batch;
164
165 // Process tokens in chunks
166 int32_t processed = 0;
167 while (processed < n_tokens) {
168 const int32_t n_eval = std::min(n_tokens - processed, n_batch);
169
170 // Clear batch using llama.cpp common library
171 common_batch_clear(batch);
172
173 // Add tokens one by one, mark logits=true only on the final chunk's last token
174 const bool is_last_chunk = (processed + n_eval >= n_tokens);
175 for (int32_t i = 0; i < n_eval; ++i) {
176 const int32_t pos = n_past + i;
177 const bool want_logits = is_last_chunk && (i == n_eval - 1);
178
179 // Add token via llama.cpp common library (function-call ABI).
180 // {seq_id} constructs a temporary vector per token — acceptable cost
181 // vs direct field writes which create struct-layout ABI coupling.
182 common_batch_add(batch, tokens[processed + i], pos, {seq_id}, want_logits);
183 }
184
185 // Decode chunk (updates KV cache)
186 const int rc = llama_decode(ctx, batch);
187 if (rc != 0) {
189 "[decode::many] ERROR: llama_decode failed at position %d (rc=%d)",
190 n_past, rc);
191 return rc;
192 }
193
194 n_past += n_eval;
195 processed += n_eval;
196
197 LLOYAL_LOG_DEBUG("[decode::many] Processed %d/%d tokens",
198 processed, n_tokens);
199 }
200
201 LLOYAL_LOG_DEBUG("[decode::many] Decode complete");
202 return 0;
203}
204
206[[nodiscard]] inline int many(llama_context *ctx,
207 const std::vector<llama_token> &tokens,
208 int32_t n_past, int32_t n_batch,
209 llama_seq_id seq_id = 0) {
210 return many(ctx, tokens.data(), static_cast<int32_t>(tokens.size()), n_past,
211 n_batch, seq_id);
212}
213
238[[nodiscard]] inline int one(llama_context *ctx, llama_token tok, llama_pos pos,
239 llama_seq_id seq_id = 0, bool want_logits = true) {
240 if (!ctx) {
241 throw std::runtime_error("decode::one - NULL context");
242 }
243
244 struct ThreadLocalBatch {
245 llama_batch batch = llama_batch_init(1, 0, 1);
246 ~ThreadLocalBatch() { llama_batch_free(batch); }
247 };
248 thread_local ThreadLocalBatch tl;
249
250 common_batch_clear(tl.batch);
251 common_batch_add(tl.batch, tok, pos, {seq_id}, want_logits);
252
253 return llama_decode(ctx, tl.batch);
254}
255
256// ============================================================================
257// Multi-Sequence Decode
258// ============================================================================
259
269
  std::span<const llama_token> tokens; // Token run to decode (non-owning view)
  llama_pos start_pos;                 // KV-cache position of the run's first token
  llama_seq_id seq_id;                 // Target sequence ID
  bool output_logits = false;          // When true, compute logits for the run's last token
};
283
// Reusable scratch buffers for multi-sequence batch construction.
// Parallel per-token arrays backing a llama_batch; index i of every vector
// describes token i. Callers (each/scatter) fill these after resize() and
// then submit via as_batch().
struct Scratch {
  std::vector<llama_token> tokens_;         // token ids
  std::vector<llama_pos> pos_;              // KV-cache position per token
  std::vector<int32_t> n_seq_id_;           // sequence count per token (callers set 1)
  std::vector<llama_seq_id> seq_id_single_; // the single seq id per token
  std::vector<llama_seq_id*> seq_id_ptrs_;  // per-token pointers into seq_id_single_;
                                            // callers rebuild them after each resize(),
                                            // since resizing may reallocate the backing store
  std::vector<int8_t> logits_;              // 1 = compute logits for this token

  // Grow/shrink all parallel arrays to hold n tokens.
  void resize(int32_t n) {
    tokens_.resize(n);
    pos_.resize(n);
    n_seq_id_.resize(n);
    seq_id_single_.resize(n);
    seq_id_ptrs_.resize(n);
    logits_.resize(n);
  }

  // Build a llama_batch that views (not copies) the scratch arrays.
  // ABI-sensitive: writes llama_batch fields directly — no common_batch_*
  // wrapper exists for externally-owned buffers. The returned batch is only
  // valid while this Scratch is alive and un-resized.
  llama_batch as_batch(int32_t n_tokens) {
    llama_batch batch{};
    batch.n_tokens = n_tokens;
    batch.token = tokens_.data();
    batch.embd = nullptr;
    batch.pos = pos_.data();
    batch.n_seq_id = n_seq_id_.data();
    batch.seq_id = seq_id_ptrs_.data();
    batch.logits = logits_.data();
    return batch;
  }
};
321
339[[nodiscard]] inline int each(llama_context* ctx,
340 const EachItem* items,
341 int32_t n,
342 Scratch& scratch) {
343 if (!ctx) {
344 throw std::runtime_error("decode::each - NULL context");
345 }
346 if (n < 0) {
347 throw std::runtime_error("decode::each - negative item count");
348 }
349 if (n == 0) return 0;
350
351 scratch.resize(n);
352
353 for (int32_t i = 0; i < n; ++i) {
354 scratch.tokens_[i] = items[i].token;
355 scratch.pos_[i] = items[i].pos;
356 scratch.n_seq_id_[i] = 1;
357 scratch.seq_id_single_[i] = items[i].seq_id;
358 scratch.seq_id_ptrs_[i] = &scratch.seq_id_single_[i];
359 scratch.logits_[i] = items[i].output_logits ? int8_t{1} : int8_t{0};
360 }
361
362 llama_batch batch = scratch.as_batch(n);
363
364 LLOYAL_LOG_DEBUG("[decode::each] Submitting %d tokens across %d sequences", n, n);
365
366 return llama_decode(ctx, batch);
367}
368
370[[nodiscard]] inline int each(llama_context* ctx,
371 const std::vector<EachItem>& items,
372 Scratch& scratch) {
373 return each(ctx, items.data(), static_cast<int32_t>(items.size()), scratch);
374}
375
395[[nodiscard]] inline int scatter(llama_context* ctx,
396 const ScatterItem* items,
397 int32_t n,
398 Scratch& scratch) {
399 if (!ctx) {
400 throw std::runtime_error("decode::scatter - NULL context");
401 }
402 if (n < 0) {
403 throw std::runtime_error("decode::scatter - negative item count");
404 }
405
406 int32_t total = 0;
407 for (int32_t i = 0; i < n; ++i) {
408 total += static_cast<int32_t>(items[i].tokens.size());
409 }
410 if (total == 0) return 0;
411
412 scratch.resize(total);
413
414 int32_t cursor = 0;
415 for (int32_t i = 0; i < n; ++i) {
416 const auto& item = items[i];
417 const llama_pos base_pos = item.start_pos;
418 const int32_t item_n = static_cast<int32_t>(item.tokens.size());
419
420 for (int32_t j = 0; j < item_n; ++j) {
421 scratch.tokens_[cursor] = item.tokens[j];
422 scratch.pos_[cursor] = base_pos + j;
423 scratch.n_seq_id_[cursor] = 1;
424 scratch.seq_id_single_[cursor] = item.seq_id;
425 scratch.seq_id_ptrs_[cursor] = &scratch.seq_id_single_[cursor];
426
427 const bool want_logits =
428 item.output_logits ? (j == item_n - 1) : false;
429 scratch.logits_[cursor] = want_logits ? int8_t{1} : int8_t{0};
430
431 ++cursor;
432 }
433 }
434
435 llama_batch batch = scratch.as_batch(total);
436
437 LLOYAL_LOG_DEBUG("[decode::scatter] Submitting %d total tokens across %d sequences", total, n);
438
439 return llama_decode(ctx, batch);
440}
441
443[[nodiscard]] inline int scatter(llama_context* ctx,
444 const std::vector<ScatterItem>& items,
445 Scratch& scratch) {
446 return scatter(ctx, items.data(), static_cast<int32_t>(items.size()), scratch);
447}
448
449// ============================================================================
450// Bin-Packing Utility
451// ============================================================================
452
  std::vector<int32_t> indices; // Indices into the original items array
  bool oversized = false;       // True → a single item whose size exceeds n_batch
};
464
480inline std::vector<PackedChunk> bin_pack(
481 const std::span<const llama_token>* items,
482 int32_t n,
483 int32_t n_batch) {
484
485 std::vector<PackedChunk> chunks;
486 int32_t chunk_total = 0;
487
488 for (int32_t i = 0; i < n; ++i) {
489 int32_t tc = static_cast<int32_t>(items[i].size());
490 if (tc == 0) continue;
491
492 if (tc > n_batch) {
493 chunks.push_back({{i}, true});
494 continue;
495 }
496
497 if (chunks.empty() || chunks.back().oversized ||
498 chunk_total + tc > n_batch) {
499 chunks.push_back({{i}, false});
500 chunk_total = tc;
501 } else {
502 chunks.back().indices.push_back(i);
503 chunk_total += tc;
504 }
505 }
506
507 return chunks;
508}
509
510} // namespace lloyal::decode
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:47
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
Definition decode.hpp:480
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
Definition decode.hpp:238
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
Definition decode.hpp:124
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
Definition decode.hpp:395
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Definition decode.hpp:339
Input item for decode::each — one token for one sequence.
Definition decode.hpp:263
llama_token token
Token to decode.
Definition decode.hpp:264
bool output_logits
Whether to compute logits after this token.
Definition decode.hpp:267
llama_seq_id seq_id
Target sequence ID.
Definition decode.hpp:266
llama_pos pos
KV cache position for this token.
Definition decode.hpp:265
A chunk of item indices produced by bin_pack()
Definition decode.hpp:460
bool oversized
True → single item exceeding n_batch.
Definition decode.hpp:462
std::vector< int32_t > indices
Indices into the original items array.
Definition decode.hpp:461
Input item for decode::scatter — multiple tokens for one sequence.
Definition decode.hpp:277
llama_pos start_pos
KV cache position for first token.
Definition decode.hpp:279
bool output_logits
When true, compute logits for last token in this run.
Definition decode.hpp:281
std::span< const llama_token > tokens
Token array (non-owning view)
Definition decode.hpp:278
llama_seq_id seq_id
Target sequence ID.
Definition decode.hpp:280
Reusable scratch buffers for multi-sequence batch construction.
Definition decode.hpp:290
std::vector< llama_seq_id > seq_id_single_
Definition decode.hpp:294
std::vector< llama_token > tokens_
Definition decode.hpp:291
std::vector< int32_t > n_seq_id_
Definition decode.hpp:293
std::vector< int8_t > logits_
Definition decode.hpp:296
void resize(int32_t n)
Definition decode.hpp:298
std::vector< llama_seq_id * > seq_id_ptrs_
Definition decode.hpp:295
llama_batch as_batch(int32_t n_tokens)
ABI-sensitive: writes llama_batch fields directly (no common_batch_* wrapper exists for external-buff...
Definition decode.hpp:309
std::vector< llama_pos > pos_
Definition decode.hpp:292