liblloyal 1.0.0
Branched Inference for llama.cpp
decode.hpp
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6
7#include "common.hpp"
8#include <common.h> // llama.cpp common library: common_batch_clear, common_batch_add
9#include <algorithm>
10#include <cstdint>
11#include <llama/llama.h>
12#include <span>
13#include <stdexcept>
14#include <vector>
15
73namespace lloyal::decode {
74
125[[nodiscard]] inline int many(llama_context *ctx, const llama_token *tokens,
126 int32_t n_tokens, int32_t n_past, int32_t n_batch,
127 llama_seq_id seq_id = 0) {
128  LLOYAL_LOG_DEBUG(
129 "[decode::many] Processing %d tokens at position %d", n_tokens,
130 n_past);
131
132 if (!ctx) {
133 LLOYAL_LOG_DEBUG("[decode::many] ERROR: NULL context");
134 throw std::runtime_error("decode::many - NULL context");
135 }
136
137 if (!tokens || n_tokens <= 0) {
138 LLOYAL_LOG_DEBUG("[decode::many] ERROR: Invalid token array");
139 throw std::runtime_error("decode::many - Invalid token array");
140 }
141
142 if (n_batch <= 0) {
143 throw std::runtime_error("decode::many - n_batch must be positive");
144 }
145
146 // Thread-local batch avoids per-call allocation. Grows if needed, never shrinks.
147 struct ThreadLocalBatch {
148 llama_batch batch{};
149 int32_t capacity = 0;
150
151 void ensure(int32_t n) {
152 if (n <= capacity) return;
153 if (capacity > 0) llama_batch_free(batch);
154 batch = llama_batch_init(n, 0, 1);
155 capacity = n;
156 }
157
158 ~ThreadLocalBatch() {
159 if (capacity > 0) llama_batch_free(batch);
160 }
161 };
162 thread_local ThreadLocalBatch tl;
163 tl.ensure(n_batch);
164 llama_batch& batch = tl.batch;
165
166 // Process tokens in chunks
167 int32_t processed = 0;
168 while (processed < n_tokens) {
169 const int32_t n_eval = std::min(n_tokens - processed, n_batch);
170
171 // Clear batch using llama.cpp common library
172 common_batch_clear(batch);
173
174 // Add tokens one by one, mark logits=true only on the final chunk's last token
175 const bool is_last_chunk = (processed + n_eval >= n_tokens);
176 for (int32_t i = 0; i < n_eval; ++i) {
177 const int32_t pos = n_past + i;
178 const bool want_logits = is_last_chunk && (i == n_eval - 1);
179
180 // Add token via llama.cpp common library (function-call ABI).
181 // {seq_id} constructs a temporary vector per token — acceptable cost
182 // vs direct field writes which create struct-layout ABI coupling.
183 common_batch_add(batch, tokens[processed + i], pos, {seq_id}, want_logits);
184 }
185
186 // Decode chunk (updates KV cache)
187 const int rc = llama_decode(ctx, batch);
188 if (rc != 0) {
189      LLOYAL_LOG_DEBUG(
190 "[decode::many] ERROR: llama_decode failed at position %d (rc=%d)",
191 n_past, rc);
192 return rc;
193 }
194
195 n_past += n_eval;
196 processed += n_eval;
197
198 LLOYAL_LOG_DEBUG("[decode::many] Processed %d/%d tokens",
199 processed, n_tokens);
200 }
201
202 LLOYAL_LOG_DEBUG("[decode::many] Decode complete");
203 return 0;
204}
205
207[[nodiscard]] inline int many(llama_context *ctx,
208 const std::vector<llama_token> &tokens,
209 int32_t n_past, int32_t n_batch,
210 llama_seq_id seq_id = 0) {
211 return many(ctx, tokens.data(), static_cast<int32_t>(tokens.size()), n_past,
212 n_batch, seq_id);
213}
214
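A minimal usage sketch for prefill, assuming a loaded llama_context* ctx and a prompt that has already been tokenized (tokenization elided); llama_n_batch and llama_get_logits_ith are standard llama.cpp calls:

std::vector<llama_token> prompt; // filled by tokenization (elided)
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(ctx));
if (lloyal::decode::many(ctx, prompt, /*n_past=*/0, n_batch) != 0) {
  // handle decode failure
}
// The last token of the final chunk requested logits, so this is valid:
const float* logits = llama_get_logits_ith(ctx, -1);
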
239[[nodiscard]] inline int one(llama_context *ctx, llama_token tok, llama_pos pos,
240 llama_seq_id seq_id = 0, bool want_logits = true) {
241 if (!ctx) {
242 throw std::runtime_error("decode::one - NULL context");
243 }
244
245 struct ThreadLocalBatch {
246 llama_batch batch = llama_batch_init(1, 0, 1);
247 ~ThreadLocalBatch() { llama_batch_free(batch); }
248 };
249 thread_local ThreadLocalBatch tl;
250
251 common_batch_clear(tl.batch);
252 common_batch_add(tl.batch, tok, pos, {seq_id}, want_logits);
253
254 return llama_decode(ctx, tl.batch);
255}
256
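A sketch of the per-token generation loop that typically follows prefill; next_token and max_new_tokens are hypothetical stand-ins for the caller's sampler and token budget, not part of this header:

llama_pos pos = static_cast<llama_pos>(prompt.size());
for (int step = 0; step < max_new_tokens; ++step) {
  // hypothetical: sample from llama_get_logits_ith(ctx, -1)
  const llama_token tok = next_token(ctx);
  if (lloyal::decode::one(ctx, tok, pos++) != 0) {
    break; // decode failed
  }
}
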
257// ============================================================================
258// Multi-Sequence Decode
259// ============================================================================
260
264struct EachItem {
265  llama_token token;          ///< Token to decode.
266  llama_pos pos;              ///< KV cache position for this token.
267  llama_seq_id seq_id;        ///< Target sequence ID.
268  bool output_logits = false; ///< Whether to compute logits after this token.
269};
270
278struct ScatterItem {
279  std::span<const llama_token> tokens; ///< Token array (non-owning view).
280  llama_pos start_pos;        ///< KV cache position for first token.
281  llama_seq_id seq_id;        ///< Target sequence ID.
282  bool output_logits = false; ///< When true, compute logits for last token in this run.
283};
284
291struct Scratch {
292 std::vector<llama_token> tokens_;
293 std::vector<llama_pos> pos_;
294 std::vector<int32_t> n_seq_id_;
295 std::vector<llama_seq_id> seq_id_single_;
296 std::vector<llama_seq_id*> seq_id_ptrs_;
297 std::vector<int8_t> logits_;
298
299 void resize(int32_t n) {
300 tokens_.resize(n);
301 pos_.resize(n);
302 n_seq_id_.resize(n);
303 seq_id_single_.resize(n);
304 seq_id_ptrs_.resize(n);
305 logits_.resize(n);
306 }
307
310 llama_batch as_batch(int32_t n_tokens) {
311 llama_batch batch{};
312 batch.n_tokens = n_tokens;
313 batch.token = tokens_.data();
314 batch.embd = nullptr;
315 batch.pos = pos_.data();
316 batch.n_seq_id = n_seq_id_.data();
317 batch.seq_id = seq_id_ptrs_.data();
318 batch.logits = logits_.data();
319 return batch;
320 }
321};
322
340[[nodiscard]] inline int each(llama_context* ctx,
341 const EachItem* items,
342 int32_t n,
343 Scratch& scratch) {
344 if (!ctx) {
345 throw std::runtime_error("decode::each - NULL context");
346 }
347 if (n < 0) {
348 throw std::runtime_error("decode::each - negative item count");
349 }
350 if (n == 0) return 0;
351
352 scratch.resize(n);
353
354 for (int32_t i = 0; i < n; ++i) {
355 scratch.tokens_[i] = items[i].token;
356 scratch.pos_[i] = items[i].pos;
357 scratch.n_seq_id_[i] = 1;
358 scratch.seq_id_single_[i] = items[i].seq_id;
359 scratch.seq_id_ptrs_[i] = &scratch.seq_id_single_[i];
360 scratch.logits_[i] = items[i].output_logits ? int8_t{1} : int8_t{0};
361 }
362
363 llama_batch batch = scratch.as_batch(n);
364
365 LLOYAL_LOG_DEBUG("[decode::each] Submitting %d tokens across %d sequences", n, n);
366
367 return llama_decode(ctx, batch);
368}
369
371[[nodiscard]] inline int each(llama_context* ctx,
372 const std::vector<EachItem>& items,
373 Scratch& scratch) {
374 return each(ctx, items.data(), static_cast<int32_t>(items.size()), scratch);
375}
376
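A sketch of one lockstep step across three branches, assuming sequences 0 through 2 already share a prefix of length n_past in the KV cache and tok_a/tok_b/tok_c were sampled per branch (both setup steps are assumptions, not established by this header):

lloyal::decode::Scratch scratch;
const std::vector<lloyal::decode::EachItem> step = {
    {tok_a, n_past, /*seq_id=*/0, /*output_logits=*/true},
    {tok_b, n_past, /*seq_id=*/1, /*output_logits=*/true},
    {tok_c, n_past, /*seq_id=*/2, /*output_logits=*/true},
};
if (lloyal::decode::each(ctx, step, scratch) != 0) {
  // handle decode failure
}
// Logits are addressed by batch index, which matches item order here:
const float* logits_seq1 = llama_get_logits_ith(ctx, 1);
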
396[[nodiscard]] inline int scatter(llama_context* ctx,
397 const ScatterItem* items,
398 int32_t n,
399 Scratch& scratch) {
400 if (!ctx) {
401 throw std::runtime_error("decode::scatter - NULL context");
402 }
403 if (n < 0) {
404 throw std::runtime_error("decode::scatter - negative item count");
405 }
406
407 int32_t total = 0;
408 for (int32_t i = 0; i < n; ++i) {
409 total += static_cast<int32_t>(items[i].tokens.size());
410 }
411 if (total == 0) return 0;
412
413 scratch.resize(total);
414
415 int32_t cursor = 0;
416 for (int32_t i = 0; i < n; ++i) {
417 const auto& item = items[i];
418 const llama_pos base_pos = item.start_pos;
419 const int32_t item_n = static_cast<int32_t>(item.tokens.size());
420
421 for (int32_t j = 0; j < item_n; ++j) {
422 scratch.tokens_[cursor] = item.tokens[j];
423 scratch.pos_[cursor] = base_pos + j;
424 scratch.n_seq_id_[cursor] = 1;
425 scratch.seq_id_single_[cursor] = item.seq_id;
426 scratch.seq_id_ptrs_[cursor] = &scratch.seq_id_single_[cursor];
427
428 const bool want_logits =
429 item.output_logits ? (j == item_n - 1) : false;
430 scratch.logits_[cursor] = want_logits ? int8_t{1} : int8_t{0};
431
432 ++cursor;
433 }
434 }
435
436 llama_batch batch = scratch.as_batch(total);
437
438 LLOYAL_LOG_DEBUG("[decode::scatter] Submitting %d total tokens across %d sequences", total, n);
439
440 return llama_decode(ctx, batch);
441}
442
444[[nodiscard]] inline int scatter(llama_context* ctx,
445 const std::vector<ScatterItem>& items,
446 Scratch& scratch) {
447 return scatter(ctx, items.data(), static_cast<int32_t>(items.size()), scratch);
448}
449
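A sketch of decoding two divergent continuations that extend a shared prefix of length n_past; cont0/cont1 are assumed per-branch token vectors, and their combined length must fit within the context's batch size (see bin_pack below for the general case):

std::vector<llama_token> cont0, cont1; // per-branch continuations (elided)
lloyal::decode::Scratch scratch;
const std::vector<lloyal::decode::ScatterItem> items = {
    {cont0, /*start_pos=*/n_past, /*seq_id=*/0, /*output_logits=*/true},
    {cont1, /*start_pos=*/n_past, /*seq_id=*/1, /*output_logits=*/true},
};
if (lloyal::decode::scatter(ctx, items, scratch) != 0) {
  // handle decode failure
}
// Branch 0's last token sits at batch index cont0.size() - 1:
const float* tail0 =
    llama_get_logits_ith(ctx, static_cast<int32_t>(cont0.size()) - 1);
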
450// ============================================================================
451// Bin-Packing Utility
452// ============================================================================
453
461struct PackedChunk {
462  std::vector<int32_t> indices; ///< Indices into the original items array.
463  bool oversized = false;       ///< True → single item exceeding n_batch.
464};
465
481inline std::vector<PackedChunk> bin_pack(
482 const std::span<const llama_token>* items,
483 int32_t n,
484 int32_t n_batch) {
485
486 std::vector<PackedChunk> chunks;
487 int32_t chunk_total = 0;
488
489 for (int32_t i = 0; i < n; ++i) {
490 int32_t tc = static_cast<int32_t>(items[i].size());
491 if (tc == 0) continue;
492
493 if (tc > n_batch) {
494 chunks.push_back({{i}, true});
495 continue;
496 }
497
498 if (chunks.empty() || chunks.back().oversized ||
499 chunk_total + tc > n_batch) {
500 chunks.push_back({{i}, false});
501 chunk_total = tc;
502 } else {
503 chunks.back().indices.push_back(i);
504 chunk_total += tc;
505 }
506 }
507
508 return chunks;
509}
510
511} // namespace lloyal::decode
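A sketch pairing bin_pack with scatter so each llama_decode call stays within n_batch; spans, starts, and the branch-per-sequence mapping are assumptions for illustration, and oversized spans are skipped here (they would need decode::many-style chunking):

std::vector<std::span<const llama_token>> spans = {cont0, cont1, cont2};
const auto chunks = lloyal::decode::bin_pack(
    spans.data(), static_cast<int32_t>(spans.size()), n_batch);

lloyal::decode::Scratch scratch;
for (const auto& chunk : chunks) {
  if (chunk.oversized) continue; // exceeds n_batch; handle separately
  std::vector<lloyal::decode::ScatterItem> items;
  for (const int32_t i : chunk.indices) {
    items.push_back({spans[i], starts[i],
                     static_cast<llama_seq_id>(i), /*output_logits=*/true});
  }
  if (lloyal::decode::scatter(ctx, items, scratch) != 0) break;
}
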
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:48
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy next-fit bin-packing of token spans into n_batch-sized chunks: each span is appended to the current chunk until it would overflow, then a new chunk is started.
Definition decode.hpp:481
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
Definition decode.hpp:239
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
Definition decode.hpp:125
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
Definition decode.hpp:396
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Definition decode.hpp:340
Input item for decode::each — one token for one sequence.
Definition decode.hpp:264
llama_token token
Token to decode.
Definition decode.hpp:265
bool output_logits
Whether to compute logits after this token.
Definition decode.hpp:268
llama_seq_id seq_id
Target sequence ID.
Definition decode.hpp:267
llama_pos pos
KV cache position for this token.
Definition decode.hpp:266
A chunk of item indices produced by bin_pack()
Definition decode.hpp:461
bool oversized
True → single item exceeding n_batch.
Definition decode.hpp:463
std::vector< int32_t > indices
Indices into the original items array.
Definition decode.hpp:462
Input item for decode::scatter — multiple tokens for one sequence.
Definition decode.hpp:278
llama_pos start_pos
KV cache position for first token.
Definition decode.hpp:280
bool output_logits
When true, compute logits for last token in this run.
Definition decode.hpp:282
std::span< const llama_token > tokens
Token array (non-owning view)
Definition decode.hpp:279
llama_seq_id seq_id
Target sequence ID.
Definition decode.hpp:281
Reusable scratch buffers for multi-sequence batch construction.
Definition decode.hpp:291
std::vector< llama_seq_id > seq_id_single_
Definition decode.hpp:295
std::vector< llama_token > tokens_
Definition decode.hpp:292
std::vector< int32_t > n_seq_id_
Definition decode.hpp:294
std::vector< int8_t > logits_
Definition decode.hpp:297
void resize(int32_t n)
Definition decode.hpp:299
std::vector< llama_seq_id * > seq_id_ptrs_
Definition decode.hpp:296
llama_batch as_batch(int32_t n_tokens)
ABI-sensitive: writes llama_batch fields directly (no common_batch_* wrapper exists for external-buffer batches).
Definition decode.hpp:310
std::vector< llama_pos > pos_
Definition decode.hpp:293