liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
logits.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6
#include <llama/llama.h>

#include <algorithm>
#include <cstring>
#include <span>
#include <stdexcept>
#include <string>
#include <vector>

#include "decode.hpp"
#include "kv.hpp"
37
38namespace lloyal::logits {
39
79inline float* get(llama_context* ctx, int32_t index = -1) {
80 if (!ctx) {
81 throw std::runtime_error("logits::get - NULL context");
82 }
83
84 float* ptr = llama_get_logits_ith(ctx, index);
85 if (!ptr) {
86 throw std::runtime_error(
87 "logits::get - Failed to get logits at index " +
88 std::to_string(index) + ". "
89 "Ensure decode() was called with logits=true for this index."
90 );
91 }
92
93 return ptr;
94}
95
/// Process an arbitrary number of complete prompts for logit extraction.
///
/// Prompts are handled in groups of at most llama_n_seq_max(ctx) parallel
/// sequences. Within each group, prompts are greedily bin-packed into
/// llama_n_batch(ctx)-sized chunks and decoded together with a single
/// decode::scatter() call; a prompt too long for one batch is decoded alone
/// via decode::many() auto-chunking. For every prompt, the logits of its
/// final token are copied into the matching output buffer, and the group's
/// KV cache entries are evicted before the next group begins.
///
/// @param ctx      llama context (must be non-NULL)
/// @param prompts  one token span per prompt (empty spans are skipped by
///                 bin_pack; their output slots are left unwritten)
/// @param output   one destination pointer per prompt; NOTE(review): each is
///                 assumed to point at a writable buffer of >= n_vocab
///                 floats — not validated here, confirm at call sites
/// @param n_vocab  number of floats copied per prompt
/// @throws std::runtime_error on NULL ctx, prompts/output size mismatch, or
///         any decode failure
inline void process_chunks(
    llama_context* ctx,
    const std::vector<std::span<const llama_token>>& prompts,
    std::vector<float*>& output,
    int32_t n_vocab) {

  if (!ctx) throw std::runtime_error("logits::process_chunks - NULL context");
  if (prompts.size() != output.size())
    throw std::runtime_error("logits::process_chunks - prompts/output size mismatch");
  if (prompts.empty()) return;

  const int32_t seq_max = static_cast<int32_t>(llama_n_seq_max(ctx));      // parallel-sequence cap per group
  const int32_t batch_limit = static_cast<int32_t>(llama_n_batch(ctx));    // max tokens per decode batch
  const int32_t n = static_cast<int32_t>(prompts.size());
  // Reused across calls on the same thread to avoid rebuilding batch buffers.
  thread_local decode::Scratch scratch;

  for (int32_t group_start = 0; group_start < n; group_start += seq_max) {
    int32_t group_size = std::min(seq_max, n - group_start);

    // bin_pack skips empties internally — pass group slice directly
    auto chunks = decode::bin_pack(&prompts[group_start], group_size, batch_limit);
    if (chunks.empty()) continue;

    for (const auto& chunk : chunks) {
      if (chunk.oversized) {
        // Single prompt larger than one batch: decode it alone and let
        // decode::many() chunk it internally. indices[0] is its only entry.
        int32_t gi = group_start + chunk.indices[0];   // global prompt index
        llama_seq_id seq = static_cast<llama_seq_id>(chunk.indices[0]);

        if (decode::many(ctx, prompts[gi].data(),
            static_cast<int32_t>(prompts[gi].size()),
            0, batch_limit, seq) != 0)
          throw std::runtime_error("logits::process_chunks - decode::many failed");

        // -1: logits of the last decoded token.
        std::memcpy(output[gi], get(ctx, -1), n_vocab * sizeof(float));
        continue;
      }

      // Normal chunk — build ScatterItems
      std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
      for (size_t k = 0; k < chunk.indices.size(); ++k) {
        int32_t gi = group_start + chunk.indices[k];
        scatter_items[k].tokens = prompts[gi];
        scatter_items[k].start_pos = 0;
        // Within-group index doubles as the seq id (< seq_max by construction).
        scatter_items[k].seq_id = static_cast<llama_seq_id>(chunk.indices[k]);
        scatter_items[k].output_logits = true;
      }

      if (decode::scatter(ctx, scatter_items.data(),
          static_cast<int32_t>(scatter_items.size()),
          scratch) != 0)
        throw std::runtime_error("logits::process_chunks - decode::scatter failed");

      // Capture logits
      // `cursor` walks the batch's logit rows in item order; each item's last
      // token sits at (cursor + item_n - 1). Assumes scatter requested a
      // logit row for every token of every item — confirm against decode.hpp.
      int32_t cursor = 0;
      for (size_t k = 0; k < scatter_items.size(); ++k) {
        int32_t gi = group_start + chunk.indices[k];
        int32_t item_n = static_cast<int32_t>(scatter_items[k].tokens.size());
        std::memcpy(output[gi], get(ctx, cursor + item_n - 1), n_vocab * sizeof(float));
        cursor += item_n;
      }
    }

    // Evict KV for this group's seq_ids (no-op on unused ids)
    for (int32_t s = 0; s < group_size; ++s) {
      kv::remove_range(ctx, static_cast<llama_seq_id>(s), 0, -1);
    }
  }
}
177
178} // namespace lloyal::logits
Batch Decoding Operations.
KV Cache Physics.
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
Definition decode.hpp:481
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
Definition decode.hpp:125
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
Definition decode.hpp:396
bool remove_range(llama_context *ctx, llama_seq_id seq, llama_pos p0, llama_pos p1)
Remove token range from KV cache sequence.
Definition kv.hpp:78
float * get(llama_context *ctx, int32_t index=-1)
Definition logits.hpp:79
void process_chunks(llama_context *ctx, const std::vector< std::span< const llama_token > > &prompts, std::vector< float * > &output, int32_t n_vocab)
Process arbitrary number of complete prompts for logit extraction.
Definition logits.hpp:109
Reusable scratch buffers for multi-sequence batch construction.
Definition decode.hpp:291