Zero-copy access to logits from the last decode call.
Returns a pointer into the internal llama.cpp logits buffer. This is a zero-copy operation — no data is copied, and the pointer is only valid until the next decode call.
#pragma once

#include <llama/llama.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <span>
#include <stdexcept>
#include <string>
#include <vector>
/// Zero-copy access to logits from the last decode call.
///
/// Returns a pointer into llama.cpp's internal logits buffer; no data is
/// copied. The pointer is owned by the context — do not free it, and treat
/// it as invalidated by the next decode.
///
/// @param ctx   llama context (must not be NULL).
/// @param index Token position to read logits for; -1 selects the last
///              position of the previous decode.
/// @return Pointer to the logits row inside the context's buffer.
/// @throws std::runtime_error if ctx is NULL or no logits are available at
///         the requested index (e.g. decode was not asked to emit them).
inline float* get(llama_context* ctx, int32_t index = -1) {
    if (ctx == nullptr) {
        throw std::runtime_error("logits::get - NULL context");
    }
    float* logits = llama_get_logits_ith(ctx, index);
    if (logits != nullptr) {
        return logits;
    }
    throw std::runtime_error(
        "logits::get - Failed to get logits at index " +
        std::to_string(index) + ". "
        "Ensure decode() was called with logits=true for this index.");
}
llama_context* ctx,
const std::vector<std::span<const llama_token>>& prompts,
std::vector<float*>& output,
int32_t n_vocab) {
if (!ctx) throw std::runtime_error("logits::process_chunks - NULL context");
if (prompts.size() != output.size())
throw std::runtime_error("logits::process_chunks - prompts/output size mismatch");
if (prompts.empty()) return;
const int32_t seq_max = static_cast<int32_t>(llama_n_seq_max(ctx));
const int32_t batch_limit = static_cast<int32_t>(llama_n_batch(ctx));
const int32_t n = static_cast<int32_t>(prompts.size());
thread_local decode::Scratch scratch;
for (int32_t group_start = 0; group_start < n; group_start += seq_max) {
int32_t group_size = std::min(seq_max, n - group_start);
if (chunks.empty()) continue;
for (const auto& chunk : chunks) {
if (chunk.oversized) {
int32_t gi = group_start + chunk.indices[0];
llama_seq_id seq = static_cast<llama_seq_id>(chunk.indices[0]);
static_cast<int32_t>(prompts[gi].size()),
0, batch_limit, seq) != 0)
throw std::runtime_error("logits::process_chunks - decode::many failed");
std::memcpy(output[gi],
get(ctx, -1), n_vocab *
sizeof(
float));
continue;
}
std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
for (size_t k = 0; k < chunk.indices.size(); ++k) {
int32_t gi = group_start + chunk.indices[k];
scatter_items[k].tokens = prompts[gi];
scatter_items[k].start_pos = 0;
scatter_items[k].seq_id = static_cast<llama_seq_id>(chunk.indices[k]);
scatter_items[k].output_logits = true;
}
static_cast<int32_t>(scatter_items.size()),
scratch) != 0)
throw std::runtime_error("logits::process_chunks - decode::scatter failed");
int32_t cursor = 0;
for (size_t k = 0; k < scatter_items.size(); ++k) {
int32_t gi = group_start + chunk.indices[k];
int32_t item_n = static_cast<int32_t>(scatter_items[k].tokens.size());
std::memcpy(output[gi],
get(ctx, cursor + item_n - 1), n_vocab *
sizeof(
float));
cursor += item_n;
}
}
for (int32_t s = 0; s < group_size; ++s) {
}
}
}
}
Batch Decoding Operations.
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
bool remove_range(llama_context *ctx, llama_seq_id seq, llama_pos p0, llama_pos p1)
Remove token range from KV cache sequence.
void process_chunks(llama_context *ctx, const std::vector< std::span< const llama_token > > &prompts, std::vector< float * > &output, int32_t n_vocab)
Process arbitrary number of complete prompts for logit extraction.