Set static logit bias for specific tokens.
Adds additive bias to token logits before sampling. Use -INFINITY to ban tokens. Replaces any existing biases. Logit bias is CLONED when forking.
#pragma once
#include <llama/llama.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <deque>
#include <functional>
#include <limits>
#include <mutex>
#include <span>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
struct KvPressure {
};
}
return static_cast<uint16_t
>(h >>
GEN_SHIFT);
}
return (
static_cast<uint32_t
>(generation) <<
GEN_SHIFT) | index;
}
struct SamplerChainEntry {
llama_sampler*
chain =
nullptr;
if (this != &o) {
}
return *this;
}
};
struct GrammarEntry {
if (this != &o) {
sampler = o.sampler; o.sampler =
nullptr;
}
return *this;
}
};
struct CachedSamplingParams {
bool operator==(
const CachedSamplingParams&)
const =
default;
};
template <SamplingParamsLike P>
using ::lloyal::detail::as_value;
return CachedSamplingParams{
as_value(p.temperature, 0.8f),
as_value(p.top_k, static_cast<int32_t>(40)),
as_value(p.top_p, 0.95f),
as_value(p.typical_p, 1.0f),
as_value(p.min_p, 0.05f),
as_value(p.penalty_repeat, 1.0f),
as_value(p.penalty_freq, 0.0f),
as_value(p.penalty_present, 0.0f),
as_value(p.penalty_last_n, static_cast<int32_t>(64)),
as_value(p.seed, static_cast<uint32_t>(0)),
};
}
struct BranchState {
llama_context*
ctx =
nullptr;
const llama_model*
model =
nullptr;
std::function<void(llama_token_data_array&)>
steer_fn;
};
struct DecodeEachItem {
};
struct DecodeScatterItem {
std::span<const llama_token>
tokens;
};
class BranchStore {
public:
/// Construct the store with at least two slots.
/// Slot 0 is a permanently-reserved sentinel (generation 0xFFFF) so that
/// handle value 0 can serve as "invalid"; every other slot goes onto the
/// freelist with index 1 at the back (the freelist is popped from the back,
/// so the lowest index is handed out first).
explicit BranchStore(size_t initial_capacity = 16) {
  // Need room for the sentinel slot plus at least one usable slot.
  const size_t capacity = (initial_capacity < 2) ? 2 : initial_capacity;
  slots_.resize(capacity);
  // Slot 0 is never allocated and never recycled.
  slots_[0].in_use = true;
  slots_[0].generation = 0xFFFF;
  // Push capacity-1 .. 1 so index 1 ends up on top of the stack.
  for (size_t idx = capacity - 1; idx >= 1; --idx) {
    freelist_.push_back(static_cast<uint16_t>(idx));
  }
}
if (tenancy_.
ctx !=
nullptr) {
}
for (size_t i = 1; i < slots_.size(); ++i) {
if (slots_[i].in_use) {
free_branch_resources(slots_[i]);
}
}
}
}
BranchState* st =
get(handle);
st->seq_id = seq;
return {handle, seq};
}
BranchState* st =
get(handle);
if (!st) return;
BranchState* p =
get(st->parent);
if (p) {
auto& c = p->children;
c.erase(std::remove(c.begin(), c.end(), handle), c.end());
}
}
if (st->position > st->fork_head) {
uint32_t unique = static_cast<uint32_t>(st->position - st->fork_head);
cells_used_ = (unique <= cells_used_) ? cells_used_ - unique : 0;
}
free_branch_resources(*st);
reset_slot(*st);
if (freelist_.size() == slots_.size() - 1) {
cells_used_ = 0;
}
}
cells_used_ = 0;
}
if (tenancy_.
ctx ==
nullptr)
return;
for (size_t i = 1; i < slots_.size(); ++i) {
if (slots_[i].in_use) {
free_branch_resources(slots_[i]);
reset_slot(slots_[i]);
}
}
cells_used_ = 0;
}
BranchState* w =
get(winner);
if (!w) throw std::runtime_error("retainOnly: invalid winner handle");
if (w->seq_id ==
NO_LEASE)
throw std::runtime_error(
"retainOnly: winner has no lease");
std::vector<BranchHandle> losers;
for (size_t i = 1; i < slots_.size(); ++i) {
if (!slots_[i].in_use) continue;
if (h == winner) continue;
losers.push_back(h);
}
for (auto h : losers)
release_slot_only(h);
w->fork_head = 0;
w->children.clear();
cells_used_ = static_cast<uint32_t>(w->position);
}
if (!tenancy_.
ctx)
return {0, 0, 0};
uint32_t n_ctx =
static_cast<uint32_t
>(llama_n_ctx(tenancy_.
ctx));
uint32_t remaining = (n_ctx > cells_used_) ? n_ctx - cells_used_ : 0;
return { n_ctx, cells_used_, remaining };
}
const BranchState* st =
get(h);
}
const BranchState* st =
get(h);
return st ? st->fork_head : 0;
}
static const std::vector<BranchHandle> empty;
const BranchState* st =
get(h);
return st ? st->children : empty;
}
const BranchState* st =
get(h);
return st ? st->children.empty() : true;
}
const BranchState* st =
get(h);
return st ? (st->seq_id !=
NO_LEASE) :
false;
}
if (index == 0) return nullptr;
if (index >= slots_.size()) return nullptr;
BranchState& slot = slots_[index];
if (!slot.in_use || slot.generation != gen) {
return nullptr;
}
return &slot;
}
return const_cast<BranchStore*
>(
this)->
get(handle);
}
template <SamplingParamsLike P>
SamplerChainEntry entry;
entry.has_dist = (temperature > 0.0f);
sampler_chains_.emplace(h, std::move(entry));
return h;
}
if (h == 0) return 0;
auto it = sampler_chains_.find(h);
if (it == sampler_chains_.end()) return 0;
SamplerChainEntry entry;
entry.has_dist = it->second.has_dist;
sampler_chains_.emplace(nh, std::move(entry));
return nh;
}
if (h != 0) sampler_chains_.erase(h);
}
if (h == 0) return nullptr;
auto it = sampler_chains_.find(h);
return it != sampler_chains_.end() ? it->second.chain : nullptr;
}
if (h == 0) return false;
auto it = sampler_chains_.find(h);
return it != sampler_chains_.end() ? it->second.has_dist : false;
}
const char* grammar_str,
const char* root = "root") {
GrammarEntry entry;
grammars_.emplace(h, std::move(entry));
return h;
}
const llama_model* model,
const char* grammar_str,
const std::vector<std::string>& trigger_patterns,
const std::vector<llama_token>& trigger_tokens,
const char* root = "root") {
GrammarEntry entry;
model, grammar_str, trigger_patterns, trigger_tokens, root);
if (!entry.sampler) return 0;
grammars_.emplace(h, std::move(entry));
return h;
}
if (h == 0) return 0;
auto it = grammars_.find(h);
if (it == grammars_.end()) return 0;
GrammarEntry entry;
grammars_.emplace(nh, std::move(entry));
return nh;
}
if (h != 0) grammars_.erase(h);
}
if (h == 0) return nullptr;
auto it = grammars_.find(h);
return it != grammars_.end() ? it->second.sampler : nullptr;
}
metrics_registry_[h] = metrics::BranchMetricsState{};
return h;
}
if (h == 0) return 0;
auto it = metrics_registry_.find(h);
if (it == metrics_registry_.end()) return 0;
metrics_registry_[nh] = it->second;
return nh;
}
if (h != 0) metrics_registry_.erase(h);
}
if (h == 0) return;
auto it = metrics_registry_.find(h);
if (it == metrics_registry_.end()) return;
if (!std::isfinite(surprisal)) return;
it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
it->second.model.count++;
}
if (h == 0) return;
auto it = metrics_registry_.find(h);
if (it == metrics_registry_.end()) return;
if (!std::isfinite(surprisal)) return;
it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
it->second.sampling.count++;
}
if (h == 0) return std::numeric_limits<float>::infinity();
auto it = metrics_registry_.find(h);
if (it == metrics_registry_.end() || it->second.model.count == 0)
return std::numeric_limits<float>::infinity();
return std::exp(it->second.model.nll_sum_nats /
static_cast<float>(it->second.model.count));
}
if (h == 0) return std::numeric_limits<float>::infinity();
auto it = metrics_registry_.find(h);
if (it == metrics_registry_.end() || it->second.sampling.count == 0)
return std::numeric_limits<float>::infinity();
return std::exp(it->second.sampling.nll_sum_nats /
static_cast<float>(it->second.sampling.count));
}
/// Decode one token per branch in a single batched llama_decode dispatch.
///
/// All branches must share one llama_context; each item contributes one
/// token at that branch's current position/seq_id, with logits requested
/// for every row. On success each branch's logits snapshot is refreshed,
/// its position advances by one, and cells_used_ grows by items.size().
///
/// @param items  one (handle, token) pair per branch; empty span is a no-op
/// @throws std::runtime_error on a stale/invalid handle, mixed contexts,
///         decode failure, or an uninitialized vocab size
void decode_each(std::span<const DecodeEachItem> items) {
if (items.empty()) return;
const int32_t n = static_cast<int32_t>(items.size());
// Resolve and validate every handle up front so nothing is decoded
// unless the whole batch is well-formed.
std::vector<BranchState*> states(n);
for (int32_t i = 0; i < n; ++i) {
states[i] =
get(items[i].handle);
if (!states[i]) {
throw std::runtime_error("BranchStore::decode_each - invalid handle at index " + std::to_string(i));
}
// A single llama_decode call can only target one context.
if (i > 0 && states[i]->ctx != states[0]->ctx) {
throw std::runtime_error("BranchStore::decode_each - all branches must share the same context");
}
}
// Build the batch: one row per branch, logits requested for every row.
std::vector<decode::EachItem> decode_items(n);
for (int32_t i = 0; i < n; ++i) {
decode_items[i].token = items[i].token;
decode_items[i].pos = states[i]->position;
decode_items[i].seq_id = states[i]->seq_id;
decode_items[i].output_logits = true;
}
if (
decode::each(states[0]->ctx, decode_items.data(), n, scratch_) != 0) {
throw std::runtime_error("BranchStore::decode_each - llama_decode failed");
}
llama_context* ctx = states[0]->ctx;
// Snapshot each branch's logits row and advance its position.
for (int32_t i = 0; i < n; ++i) {
if (states[i]->n_vocab <= 0) {
throw std::runtime_error("BranchStore::decode_each - invalid vocab size at index " + std::to_string(i));
}
// NOTE(review): `raw_logits` is used below but its declaration is not
// visible in this excerpt — presumably fetched per-row from `ctx` (cf.
// the logits::get(...) call in decode_scatter). Confirm against the
// full file.
assert(states[i]->logits_snapshot.size() >= static_cast<size_t>(states[i]->n_vocab));
std::memcpy(states[i]->logits_snapshot.data(), raw_logits,
states[i]->n_vocab * sizeof(float));
states[i]->has_logits = true;
states[i]->position += 1;
}
// One new KV cell is consumed per decoded token.
cells_used_ += static_cast<uint32_t>(items.size());
}
if (items.empty()) return;
const int32_t n = static_cast<int32_t>(items.size());
std::vector<BranchState*> states(n);
for (int32_t i = 0; i < n; ++i) {
states[i] =
get(items[i].handle);
if (!states[i]) {
throw std::runtime_error("BranchStore::decode_scatter - invalid handle at index " + std::to_string(i));
}
if (i > 0 && states[i]->ctx != states[0]->ctx) {
throw std::runtime_error("BranchStore::decode_scatter - all branches must share the same context");
}
}
llama_context* ctx = states[0]->ctx;
const int32_t batch_limit = static_cast<int32_t>(llama_n_batch(ctx));
std::vector<std::span<const llama_token>> spans(n);
for (int32_t i = 0; i < n; ++i) {
spans[i] = items[i].tokens;
}
for (const auto& chunk : chunks) {
if (chunk.oversized) {
int32_t idx = chunk.indices[0];
int32_t tc = static_cast<int32_t>(items[idx].tokens.size());
states[idx]->position, batch_limit,
states[idx]->seq_id) != 0) {
throw std::runtime_error("BranchStore::decode_scatter - decode::many failed for oversized item " + std::to_string(idx));
}
assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
states[idx]->n_vocab * sizeof(float));
states[idx]->has_logits = true;
states[idx]->position += tc;
continue;
}
std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
for (size_t k = 0; k < chunk.indices.size(); ++k) {
int32_t idx = chunk.indices[k];
scatter_items[k].tokens = items[idx].tokens;
scatter_items[k].start_pos = states[idx]->position;
scatter_items[k].seq_id = states[idx]->seq_id;
scatter_items[k].output_logits = true;
}
static_cast<int32_t>(scatter_items.size()),
scratch_) != 0) {
throw std::runtime_error("BranchStore::decode_scatter - decode::scatter failed");
}
int32_t cursor = 0;
for (size_t k = 0; k < scatter_items.size(); ++k) {
int32_t idx = chunk.indices[k];
int32_t item_n = static_cast<int32_t>(scatter_items[k].tokens.size());
const float* raw_logits =
logits::get(ctx, cursor + item_n - 1);
assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
states[idx]->n_vocab * sizeof(float));
states[idx]->has_logits = true;
states[idx]->position += static_cast<int32_t>(items[idx].tokens.size());
cursor += item_n;
}
}
for (int32_t i = 0; i < n; ++i) {
cells_used_ += static_cast<uint32_t>(items[i].tokens.size());
}
}
private:
/// Release all per-branch CPU-side resources owned by this slot.
///
/// Frees the slot's sampler chain, grammar, boundary tracker, and metrics
/// tracker through their registries, then zeroes the handles/pointers so a
/// subsequent reset_slot() — or an accidental double call — is harmless.
/// KV-cache eviction is handled elsewhere; this touches CPU state only.
///
/// Fix: the previous version zeroed the handles without calling
/// free_sampler()/free_grammar()/free_metrics(), leaking the corresponding
/// entries in sampler_chains_/grammars_/metrics_registry_ (those maps are
/// only erased via the free_* members).
void free_branch_resources(BranchState& slot) {
  if (slot.sampler_chain != 0) {
    free_sampler(slot.sampler_chain);  // erase registry entry (and chain)
    slot.sampler_chain = 0;
  }
  if (slot.grammar != 0) {
    free_grammar(slot.grammar);  // erase grammar registry entry
    slot.grammar = 0;
  }
  if (slot.boundary_tracker) {
    // Raw owning pointer (fork clones it via clone().release()) — delete.
    delete slot.boundary_tracker;
    slot.boundary_tracker = nullptr;
  }
  if (slot.metrics != 0) {
    free_metrics(slot.metrics);  // erase metrics registry entry
    slot.metrics = 0;
  }
}
/// Return a slot to its pristine state after its resources were freed.
///
/// Bumps the generation counter so any outstanding handles to this slot
/// become stale (get() rejects a generation mismatch), then clears every
/// cached field. Expects free_branch_resources() to have run first —
/// resource handles are simply zeroed here, not freed.
void reset_slot(BranchState& slot) {
slot.in_use = false;
// Invalidate old handles: a handle encodes (generation << GEN_SHIFT) | index,
// and get() compares the stored generation against the handle's.
slot.generation = static_cast<uint16_t>(slot.generation + 1);
slot.ctx = nullptr;
slot.model = nullptr;
slot.position = 0;
slot.fork_head = 0;
slot.sampler_chain = 0;
slot.grammar = 0;
slot.metrics = 0;
slot.cached_params = CachedSamplingParams{};
slot.last_token = -1;
slot.last_candidates.clear();
slot.logits_snapshot.clear();
slot.has_logits = false;
slot.logit_bias.clear();
slot.steer_fn = nullptr;
slot.candidates_buffer.clear();
slot.n_vocab = 0;
slot.children.clear();
// NOTE(review): `parent` and `seq_id` are not reset here, yet the fork
// path asserts a freshly allocated slot has parent == INVALID_HANDLE and
// empty topology. Confirm allocate() reinitializes both; otherwise stale
// parent/lease state can survive into a recycled slot.
}
if (freelist_.empty()) {
size_t old_size = slots_.size();
size_t new_size = old_size * 2;
}
if (old_size >= new_size) {
}
slots_.resize(new_size);
for (size_t i = new_size; i-- > old_size; ) {
freelist_.push_back(static_cast<uint16_t>(i));
}
}
uint16_t index = freelist_.back();
freelist_.pop_back();
BranchState& slot = slots_[index];
slot.in_use = true;
}
BranchState* st =
get(handle);
if (!st) return;
BranchState* p =
get(st->parent);
if (p) {
auto& c = p->children;
c.erase(std::remove(c.begin(), c.end(), handle), c.end());
}
}
free_branch_resources(*st);
reset_slot(*st);
}
std::deque<BranchState> slots_;
std::vector<uint16_t> freelist_;
kv::tenancy::State tenancy_;
uint32_t cells_used_ = 0;
decode::Scratch scratch_;
std::unordered_map<SamplerChainHandle, SamplerChainEntry> sampler_chains_;
std::unordered_map<GrammarHandle, GrammarEntry> grammars_;
std::unordered_map<MetricsHandle, metrics::BranchMetricsState> metrics_registry_;
};
template <SamplingParamsLike P>
llama_context* ctx,
const llama_model* model,
BranchStore& s,
llama_pos start_pos,
const P& params,
const char* grammar_str = nullptr,
boundaries::BoundaryTracker* boundary_tracker = nullptr) {
if (!ctx || !model) {
}
auto [handle, seq_id] = s.allocate();
}
BranchState* state = s.get(handle);
if (!state) {
s.release(handle);
}
state->ctx = ctx;
state->model = model;
state->position = start_pos;
state->n_batch = n_batch;
const llama_vocab* vocab = llama_model_get_vocab(model);
state->n_vocab = llama_vocab_n_tokens(vocab);
state->logits_snapshot.resize(state->n_vocab);
state->has_logits = false;
state->candidates_buffer.resize(state->n_vocab);
state->sampler_chain = s.create_sampler(params);
if (grammar_str && grammar_str[0] != '\0') {
state->grammar = s.create_grammar(model, grammar_str);
}
state->boundary_tracker = boundary_tracker;
state->metrics = s.create_metrics();
handle, seq_id, start_pos);
return handle;
}
BranchState* src = s.get(source);
if (!src) {
}
auto [new_handle, new_seq_id] = s.allocate();
}
BranchState* dst = s.get(new_handle);
if (!dst) {
s.release(new_handle);
}
dst->ctx = src->ctx;
dst->model = src->model;
dst->seq_id = new_seq_id;
dst->position = src->position;
dst->fork_head = src->position;
dst->n_batch = src->n_batch;
dst->n_vocab = src->n_vocab;
#ifndef NDEBUG
assert(
kv::pos_max(src->ctx, new_seq_id) < 0 &&
"tenancy: acquired seq must be clean");
assert(dst->parent ==
INVALID_HANDLE && dst->children.empty() &&
"fresh slot must have no topology");
#endif
dst->parent = source;
src->children.push_back(new_handle);
if (src->sampler_chain != 0) {
dst->sampler_chain = s.clone_sampler(src->sampler_chain);
}
dst->cached_params = src->cached_params;
if (src->grammar != 0) {
dst->grammar = s.clone_grammar(src->grammar);
}
if (src->boundary_tracker) {
dst->boundary_tracker = src->boundary_tracker->clone().release();
}
if (src->metrics != 0) {
dst->metrics = s.clone_metrics(src->metrics);
}
dst->last_token = src->last_token;
dst->last_candidates = src->last_candidates;
dst->logits_snapshot = src->logits_snapshot;
dst->has_logits = src->has_logits;
dst->logit_bias = src->logit_bias;
dst->candidates_buffer.resize(dst->n_vocab);
source, new_handle, src->seq_id, new_seq_id);
return new_handle;
}
const llama_logit_bias* biases,
size_t n_biases,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("set_logit_bias: invalid branch handle");
}
state->logit_bias.assign(biases, biases + n_biases);
n_biases, handle);
}
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("clear_logit_bias: invalid branch handle");
}
state->logit_bias.clear();
LLOYAL_LOG_DEBUG(
"[branch::clear_logit_bias] Cleared biases on handle=%u", handle);
}
std::function<void(llama_token_data_array&)> steer_fn,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("set_steer: invalid branch handle");
}
state->steer_fn = std::move(steer_fn);
LLOYAL_LOG_DEBUG(
"[branch::set_steer] Set steer callback on handle=%u", handle);
}
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("clear_steer: invalid branch handle");
}
state->steer_fn = nullptr;
LLOYAL_LOG_DEBUG(
"[branch::clear_steer] Cleared steer callback on handle=%u", handle);
}
template <SamplingParamsLike P>
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("set_sampler_params: invalid branch handle");
}
if (new_params == state->cached_params && state->sampler_chain != 0) {
return;
}
if (state->sampler_chain != 0) {
s.free_sampler(state->sampler_chain);
}
state->sampler_chain = s.create_sampler(params);
state->cached_params = new_params;
LLOYAL_LOG_DEBUG(
"[branch::set_sampler_params] Rebuilt chain on handle=%u temp=%.3f",
handle, new_params.temperature);
}
const llama_model* model,
const char* grammar_str,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("set_grammar: invalid branch handle");
}
if (state->grammar != 0) {
s.free_grammar(state->grammar);
state->grammar = 0;
}
if (grammar_str && grammar_str[0] != '\0') {
state->grammar = s.create_grammar(model, grammar_str);
}
state->grammar != 0 ? "Set" : "Cleared", handle);
}
const llama_model* model,
const char* grammar_str,
const std::vector<std::string>& trigger_patterns,
const std::vector<llama_token>& trigger_tokens,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("set_grammar_lazy: invalid branch handle");
}
if (state->grammar != 0) {
s.free_grammar(state->grammar);
state->grammar = 0;
}
if (grammar_str && grammar_str[0] != '\0') {
state->grammar = s.create_grammar_lazy(
model, grammar_str, trigger_patterns, trigger_tokens);
}
state->grammar != 0 ? "Set" : "Cleared", handle);
}
BranchState* state = s.get(handle);
if (!state) return;
if (!state->children.empty())
throw std::runtime_error("prune: RESTRICT — branch has children. Use pruneSubtree() for CASCADE.");
s.release(handle);
}
std::vector<BranchHandle> stack{h}, post_order;
while (!stack.empty()) {
post_order.push_back(cur);
BranchState* st = s.get(cur);
if (st) for (auto child : st->children) stack.push_back(child);
}
for (auto it = post_order.rbegin(); it != post_order.rend(); ++it)
}
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("force_snapshot_logits: invalid branch handle");
}
if (state->n_vocab <= 0) {
throw std::runtime_error("force_snapshot_logits: invalid vocab size");
}
std::memcpy(state->logits_snapshot.data(), raw_logits,
state->n_vocab * sizeof(float));
state->has_logits = true;
}
const llama_token* tokens,
size_t n_tokens,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("prefill: invalid branch handle");
}
if (
decode::many(state->ctx, tokens,
static_cast<int32_t
>(n_tokens),
state->position, state->n_batch, state->seq_id) != 0) {
throw std::runtime_error("prefill: llama_decode failed");
}
state->position += static_cast<llama_pos>(n_tokens);
s.add_cells_used(static_cast<uint32_t>(n_tokens));
if (state->n_vocab <= 0) {
throw std::runtime_error("prefill: invalid vocab size");
}
std::memcpy(state->logits_snapshot.data(), raw_logits,
state->n_vocab * sizeof(float));
state->has_logits = true;
}
llama_token token,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) {
throw std::runtime_error("step: invalid branch handle");
}
if (
decode::one(state->ctx, token, state->position, state->seq_id,
true) != 0) {
throw std::runtime_error("step: llama_decode failed");
}
state->position += 1;
s.add_cells_used(1);
if (state->n_vocab <= 0) {
throw std::runtime_error("step: invalid vocab size");
}
std::memcpy(state->logits_snapshot.data(), raw_logits,
state->n_vocab * sizeof(float));
state->has_logits = true;
}
const BranchState* state = s.get(handle);
if (!state || !state->has_logits) {
return nullptr;
}
return state->logits_snapshot.data();
}
BranchState* state = s.get(handle);
llama_sampler* chain = state ? s.get_sampler_chain(state->sampler_chain) : nullptr;
if (!state || !chain) {
return -1;
}
if (!state->has_logits) {
LLOYAL_LOG_DEBUG(
"[branch::sample] No logits captured - call prefill()/step() first");
return -1;
}
for (int i = 0; i < state->n_vocab; i++) {
state->candidates_buffer[i] = llama_token_data{
static_cast<llama_token>(i),
state->logits_snapshot[i],
0.0f};
}
llama_token_data_array cur_p = {
state->candidates_buffer.data(),
static_cast<size_t>(state->n_vocab),
-1,
false};
llama_sampler* gram = s.get_grammar_sampler(state->grammar);
if (gram) {
}
if (!state->logit_bias.empty()) {
for (const auto& bias : state->logit_bias) {
if (bias.token >= 0 && bias.token < state->n_vocab) {
cur_p.data[bias.token].logit += bias.bias;
}
}
}
if (state->steer_fn) {
try {
state->steer_fn(cur_p);
} catch (const std::exception& e) {
}
}
if (cur_p.selected == -1) {
return -1;
}
llama_token token = cur_p.data[cur_p.selected].id;
state->last_token = token;
state->last_candidates.clear();
state->last_candidates.reserve(cur_p.size);
for (size_t i = 0; i < cur_p.size; i++) {
state->last_candidates.push_back(cur_p.data[i]);
}
return token;
}
llama_token token,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state) return;
llama_sampler* gram = s.get_grammar_sampler(state->grammar);
if (gram) {
}
llama_sampler* chain = s.get_sampler_chain(state->sampler_chain);
if (chain) {
}
if (state->metrics != 0 && state->has_logits) {
state->logits_snapshot.data(), state->n_vocab, token);
s.add_model_surprisal(state->metrics, ms);
}
if (state->metrics != 0 && !state->last_candidates.empty() &&
token == state->last_token) {
std::vector<float> candidate_logits;
std::vector<int32_t> candidate_ids;
candidate_logits.reserve(state->last_candidates.size());
candidate_ids.reserve(state->last_candidates.size());
for (const auto& cand : state->last_candidates) {
candidate_logits.push_back(cand.logit);
candidate_ids.push_back(cand.id);
}
candidate_logits.data(),
candidate_ids.data(),
static_cast<int>(candidate_logits.size()),
token
);
s.add_sampling_surprisal(state->metrics, ss);
}
}
float* logits,
int n_vocab,
BranchStore& s) {
BranchState* state = s.get(handle);
llama_sampler* gram = state ? s.get_grammar_sampler(state->grammar) : nullptr;
if (!state || !gram) return;
std::vector<llama_token_data>* candidates_ptr;
std::vector<llama_token_data> temp_buffer;
if (n_vocab == state->n_vocab && !state->candidates_buffer.empty()) {
candidates_ptr = &state->candidates_buffer;
} else {
temp_buffer.resize(n_vocab);
candidates_ptr = &temp_buffer;
}
auto& candidates = *candidates_ptr;
for (int i = 0; i < n_vocab; i++) {
candidates[i] = llama_token_data{
static_cast<llama_token>(i), logits[i], 0.0f};
}
llama_token_data_array cur_p = {
candidates.data(),
static_cast<size_t>(n_vocab),
-1,
false};
for (int i = 0; i < n_vocab; i++) {
logits[i] = candidates[i].logit;
}
}
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state || !state->has_logits) {
return {};
}
for (int i = 0; i < state->n_vocab; i++) {
state->candidates_buffer[i] = llama_token_data{
static_cast<llama_token>(i),
state->logits_snapshot[i],
0.0f};
}
llama_token_data_array cur_p = {
state->candidates_buffer.data(),
static_cast<size_t>(state->n_vocab),
-1,
false};
llama_sampler* gram = s.get_grammar_sampler(state->grammar);
if (gram) {
}
std::vector<std::pair<llama_token, float>> legal_priors;
float max_logit = -std::numeric_limits<float>::infinity();
for (size_t i = 0; i < cur_p.size; i++) {
if (std::isfinite(cur_p.data[i].logit)) {
legal_priors.emplace_back(cur_p.data[i].id, cur_p.data[i].logit);
if (cur_p.data[i].logit > max_logit) {
max_logit = cur_p.data[i].logit;
}
}
}
if (legal_priors.empty()) {
return {};
}
float sum_exp = 0.0f;
for (auto& [token, logit] : legal_priors) {
float exp_val = std::exp(logit - max_logit);
logit = exp_val;
sum_exp += exp_val;
}
for (auto& [token, prob] : legal_priors) {
prob /= sum_exp;
}
return legal_priors;
}
BranchState* state = s.get(handle);
if (!state || !state->has_logits) {
return -std::numeric_limits<float>::infinity();
}
for (int i = 0; i < state->n_vocab; i++) {
state->candidates_buffer[i] = llama_token_data{
static_cast<llama_token>(i),
state->logits_snapshot[i],
0.0f};
}
llama_token_data_array cur_p = {
state->candidates_buffer.data(),
static_cast<size_t>(state->n_vocab),
-1,
false};
llama_sampler* gram = s.get_grammar_sampler(state->grammar);
if (gram) {
}
float max_logit = -std::numeric_limits<float>::infinity();
for (size_t i = 0; i < cur_p.size; i++) {
if (std::isfinite(cur_p.data[i].logit) && cur_p.data[i].logit > max_logit) {
max_logit = cur_p.data[i].logit;
}
}
if (!std::isfinite(max_logit)) {
return -std::numeric_limits<float>::infinity();
}
float sum_exp = 0.0f;
for (size_t i = 0; i < cur_p.size; i++) {
if (std::isfinite(cur_p.data[i].logit)) {
sum_exp += std::exp(cur_p.data[i].logit - max_logit);
}
}
return max_logit + std::log(sum_exp);
}
llama_token token,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state || token < 0 || token >= state->n_vocab) {
return false;
}
llama_sampler* gram = s.get_grammar_sampler(state->grammar);
if (!gram) {
return true;
}
llama_token_data single_candidate = {
token,
state->has_logits ? state->logits_snapshot[token] : 0.0f,
0.0f
};
llama_token_data_array cur_p = {
&single_candidate,
1,
-1,
false
};
return std::isfinite(single_candidate.logit);
}
llama_token token,
float logsumexp,
BranchStore& s) {
BranchState* state = s.get(handle);
if (!state || !state->has_logits || token < 0 || token >= state->n_vocab) {
return 0.0f;
}
float logit = state->logits_snapshot[token];
return std::exp(logit - logsumexp);
}
llama_token token,
float logsumexp,
BranchStore& s) {
return 0.0f;
}
}
const BranchState* state = s.get(handle);
return state ? state->position : -1;
}
const BranchState* state = s.get(handle);
return state ? state->fork_head : 0;
}
const BranchState* state = s.get(handle);
if (!state || state->metrics == 0) {
return std::numeric_limits<float>::infinity();
}
return s.get_model_ppl(state->metrics);
}
const BranchState* state = s.get(handle);
if (!state || state->metrics == 0) {
return std::numeric_limits<float>::infinity();
}
return s.get_sampling_ppl(state->metrics);
}
const BranchState* state = s.get(handle);
if (!state || state->last_candidates.empty() || state->last_token < 0) {
return 0.0f;
}
std::vector<float> candidate_logits;
std::vector<int32_t> candidate_ids;
candidate_logits.reserve(state->last_candidates.size());
candidate_ids.reserve(state->last_candidates.size());
for (const auto& cand : state->last_candidates) {
candidate_logits.push_back(cand.logit);
candidate_ids.push_back(cand.id);
}
candidate_logits.data(),
candidate_ids.data(),
static_cast<int>(candidate_logits.size()),
state->last_token
);
return std::exp(-surprisal);
}
const BranchState* state = s.get(handle);
return state ? state->n_vocab : 0;
}
class Branch {
public:
: store_(store), handle_(
handle) {}
}
}
: store_(other.store_), handle_(other.handle_) {
}
if (this != &other) {
}
store_ = other.store_;
handle_ = other.handle_;
}
return *this;
}
template <SamplingParamsLike P>
llama_context* ctx,
const llama_model* model,
BranchStore& store,
llama_pos start_pos,
const P& params,
const char* grammar_str = nullptr,
boundaries::BoundaryTracker* boundary_tracker = nullptr) {
}
}
}
}
}
void prefill(
const llama_token* tokens,
size_t n) {
}
void step(llama_token token) {
}
}
}
void accept(llama_token token) {
}
bool is_eog(llama_token token)
const {
const BranchState* st = store_ ? store_->
get(handle_) :
nullptr;
}
template <SamplingParamsLike P>
}
const BranchState* st = store_ ? store_->
get(handle_) :
nullptr;
}
static const std::vector<BranchHandle> empty;
return store_ ? store_->
children(handle_) : empty;
}
bool isLeaf()
const {
return store_ ? store_->
isLeaf(handle_) :
true; }
private:
BranchStore* store_;
};
}
GrammarHandle create_grammar_lazy(const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const char *root="root")
Create a lazy grammar (unconstrained until trigger fires)
float get_sampling_ppl(MetricsHandle h) const
Get sampling-level perplexity from a metrics tracker.
bool isActive(BranchHandle h) const
Test whether a branch holds a KV lease.
void free_grammar(GrammarHandle h)
Free a grammar.
GrammarHandle create_grammar(const llama_model *model, const char *grammar_str, const char *root="root")
Create a grammar sampler and register it.
MetricsHandle clone_metrics(MetricsHandle h)
Clone a metrics tracker (for fork)
size_t available() const
Number of vacant seq_ids available for acquisition.
void release(BranchHandle handle)
Release a branch slot + evict its KV lease.
void drain()
Explicit teardown — evict all leases while context is alive.
float get_model_ppl(MetricsHandle h) const
Get model-level perplexity from a metrics tracker.
bool isLeaf(BranchHandle h) const
Test whether a branch is a leaf (no children)
void decode_each(std::span< const DecodeEachItem > items)
Decode one token per branch in a single GPU dispatch.
bool sampler_has_dist(SamplerChainHandle h) const
Check if a sampler chain ends with dist (stochastic) or greedy.
MetricsHandle create_metrics()
Create a metrics tracker and register it.
~BranchStore()
Destructor — frees CPU resources.
SamplerChainHandle clone_sampler(SamplerChainHandle h)
Clone a sampler chain (for fork)
llama_pos fork_head(BranchHandle h) const
Get a branch's fork head (parent position at fork time)
Allocation allocate()
Allocate a branch slot + KV lease atomically.
void add_cells_used(uint32_t n)
Increment cells_used counter (for standalone prefill/step outside BranchStore methods)
void free_sampler(SamplerChainHandle h)
Free a sampler chain.
void retainOnly(BranchHandle winner)
Keep only the winner — nuclear KV + CPU cleanup.
KvPressure kv_pressure() const
KV cache pressure snapshot — O(1), no tree walking.
SamplerChainHandle create_sampler(const P &params)
Create a sampler chain and register it.
llama_sampler * get_grammar_sampler(GrammarHandle h) const
Dereference a grammar handle (non-owning)
void add_sampling_surprisal(MetricsHandle h, float surprisal)
Add sampling-level surprisal to a metrics tracker.
void init_tenancy(llama_context *ctx)
Initialize KV tenancy after context creation.
GrammarHandle clone_grammar(GrammarHandle h)
Clone a grammar (for fork)
void decode_scatter(std::span< const DecodeScatterItem > items)
Decode variable token counts per branch with auto-chunking.
void free_metrics(MetricsHandle h)
Free a metrics tracker.
const std::vector< BranchHandle > & children(BranchHandle h) const
Get a branch's child handles.
BranchState * get(BranchHandle handle)
Look up branch state by handle.
BranchHandle parent(BranchHandle h) const
Get a branch's parent handle.
void add_model_surprisal(MetricsHandle h, float surprisal)
Add model-level surprisal to a metrics tracker.
llama_sampler * get_sampler_chain(SamplerChainHandle h) const
Dereference a sampler chain handle (non-owning)
~Branch()
Destructor — CASCADE prunes entire subtree.
void accept(llama_token token)
Accept a token — advance grammar, penalty window, and metrics.
int n_vocab() const
Vocabulary size.
void setGrammar(const char *grammar_str)
Replace grammar constraint (nullptr/empty to remove)
const std::vector< BranchHandle > & childHandles() const
Child branch handles (empty if leaf)
void setSamplerParams(const P &params)
Replace sampler chain with new parameters (memoized)
BranchHandle handle() const
Underlying opaque handle (for interop with free functions)
const float * logits() const
Get the branch's captured logits snapshot.
bool is_eog(llama_token token) const
Check if a token is end-of-generation for this branch's model.
BranchHandle parentHandle() const
Parent branch handle, or INVALID_HANDLE if root.
llama_pos position() const
Current decode position (token count)
llama_token sample()
Sample a token from captured logits.
void step(llama_token token)
Decode one token and capture logits (generation step)
void pruneSubtree()
CASCADE prune — removes entire subtree.
float perplexity() const
Model-level perplexity (from raw logits, pre-filter)
bool isActive() const
True if this branch holds a KV lease.
llama_pos forkHead() const
Parent's position at fork time (0 for root branches)
void force_snapshot_logits()
Force-copy shared logits buffer into this branch's snapshot.
bool isLeaf() const
True if this branch has no children.
void prefill(const llama_token *tokens, size_t n)
Decode multiple tokens and capture logits atomically (prompt injection)
bool valid() const
True if this Branch holds a valid handle.
void prune()
RESTRICT prune (throws if children exist)
Branch fork()
Fork: allocates slot + lease, records topology edge.
Branch & operator=(Branch &&other) noexcept
static Branch create(llama_context *ctx, const llama_model *model, BranchStore &store, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Factory: allocates slot + lease from store.
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Batch Decoding Operations.
Grammar-Constrained Sampling.
constexpr llama_seq_id NO_LEASE
Sentinel value indicating a branch has no KV residency.
Zero-copy logits access with clear lifetime semantics.
Distribution Metrics for Test-Time Alignment.
void prefill(BranchHandle handle, const llama_token *tokens, size_t n_tokens, BranchStore &s)
Decode multiple tokens and capture logits atomically (prompt prefill)
void set_logit_bias(BranchHandle handle, const llama_logit_bias *biases, size_t n_biases, BranchStore &s)
Set static logit bias for specific tokens.
float get_perplexity(BranchHandle handle, BranchStore &s)
Get model-level perplexity (from raw logits)
void apply_grammar(BranchHandle handle, float *logits, int n_vocab, BranchStore &s)
Apply grammar constraints to an external logits buffer.
int32_t GrammarHandle
Handle to a grammar sampler in BranchStore's registry (0 = invalid/none)
uint32_t BranchHandle
Opaque handle to a branch slot.
void step(BranchHandle handle, llama_token token, BranchStore &s)
Decode a single token and capture logits (generation step)
constexpr int DEFAULT_N_BATCH
Default batch size for decode operations.
void set_grammar_lazy(BranchHandle handle, const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, BranchStore &s)
Set lazy grammar on a branch (unconstrained until trigger fires)
void set_steer(BranchHandle handle, std::function< void(llama_token_data_array &)> steer_fn, BranchStore &s)
Set a dynamic steer callback on a branch.
constexpr uint32_t GEN_SHIFT
Bit shift for generation field.
void accept_token(BranchHandle handle, llama_token token, BranchStore &s)
Accept a sampled token, advancing grammar and sampler state.
void force_snapshot_logits(BranchHandle handle, BranchStore &s)
Force-copy the shared llama.cpp logits buffer into this branch's private snapshot.
llama_pos get_fork_head(BranchHandle handle, BranchStore &s)
Get the branch's fork head (parent position at fork time)
float get_sampling_perplexity(BranchHandle handle, BranchStore &s)
Get sampling-level perplexity (from filtered distribution)
float get_legal_logsumexp(BranchHandle handle, BranchStore &s)
Compute log-sum-exp over grammar-legal logits.
void clear_logit_bias(BranchHandle handle, BranchStore &s)
Clear all logit biases from a branch.
uint16_t handle_generation(BranchHandle h)
Extract generation counter from a branch handle.
const float * get_logits(BranchHandle handle, BranchStore &s)
Get the branch's captured logits snapshot.
void set_sampler_params(BranchHandle handle, const P &params, BranchStore &s)
Replace a branch's sampler chain with new parameters.
void prune(BranchHandle handle, BranchStore &s)
Prune a leaf branch (RESTRICT — throws if children exist)
BranchHandle create(llama_context *ctx, const llama_model *model, BranchStore &s, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Create a new branch with sampler chain, optional grammar, and metrics.
CachedSamplingParams snapshot_params(const P &p)
Snapshot sampling params for memoization comparison.
constexpr uint32_t INDEX_MASK
Mask for slot index field.
BranchHandle fork(BranchHandle source, BranchStore &s)
Fork a branch into a new independent sequence.
int get_n_vocab(BranchHandle handle, BranchStore &s)
Get the branch's vocabulary size.
llama_token sample(BranchHandle handle, BranchStore &s)
Sample a token from the branch's captured logits.
void pruneSubtree(BranchHandle h, BranchStore &s)
Prune a branch and all descendants (CASCADE — iterative post-order)
constexpr BranchHandle INVALID_HANDLE
Null handle sentinel.
float get_token_prior(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token, checking grammar legality first.
constexpr llama_seq_id NO_LEASE
Branch has no KV residency.
int32_t MetricsHandle
Handle to a metrics tracker in BranchStore's registry (0 = invalid/none)
llama_pos get_position(BranchHandle handle, BranchStore &s)
Get the branch's current decode position.
void set_grammar(BranchHandle handle, const llama_model *model, const char *grammar_str, BranchStore &s)
Replace a branch's grammar constraint.
void clear_steer(BranchHandle handle, BranchStore &s)
Clear the steer callback from a branch.
float get_token_prior_assume_legal(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token known to be grammar-legal.
BranchHandle make_handle(uint16_t index, uint16_t generation)
Construct a branch handle from index and generation.
uint16_t handle_index(BranchHandle h)
Extract slot index from a branch handle.
int32_t SamplerChainHandle
Handle to a sampler chain in BranchStore's registry (0 = invalid/none)
std::vector< std::pair< llama_token, float > > get_legal_priors(BranchHandle handle, BranchStore &s)
Get grammar-legal tokens with renormalized probabilities.
bool is_token_legal(BranchHandle handle, llama_token token, BranchStore &s)
Check if a token is legal under grammar constraints.
float get_last_sampling_prior(BranchHandle handle, BranchStore &s)
Get the last sampled token's prior from the filtered distribution.
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
void free_sampler(llama_sampler *smpl)
Free a grammar sampler.
llama_sampler * clone_sampler(llama_sampler *smpl)
Clone a grammar sampler (for fork/branching).
llama_sampler * init_sampler(const llama_model *model, const std::string &grammar_str, const std::string &root_rule="root")
Initialize a grammar sampler from GBNF grammar string.
llama_sampler * init_lazy_sampler(const llama_model *model, const std::string &grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const std::string &root_rule="root")
Initialize a lazy grammar sampler from GBNF grammar string.
void accept(llama_sampler *smpl, llama_token token)
Accept a token into grammar state.
void apply(llama_sampler *smpl, llama_token_data_array *cur_p)
Apply grammar constraint to candidates.
void evict_all(State &s)
Evict every leased seq_id.
llama_seq_id acquire(State &s)
Acquire a seq_id from the vacant pool.
size_t available(const State &s)
Number of vacant seq_ids available for acquisition.
void evict(State &s, llama_seq_id seq)
Evict a seq_id — strip all KV tags then release.
void retain(State &s, llama_seq_id keep)
Nuclear retain — keep one seq, rebuild vacancy from scratch.
State init(llama_context *ctx, llama_seq_id n_seq_max)
Initialize tenancy with all seq_ids vacant.
void release(State &s, llama_seq_id seq)
Release a seq_id back to vacant — bookkeeping only, no KV calls.
void seq_cp(llama_context *ctx, llama_seq_id src, llama_seq_id dst, llama_pos p0=0, llama_pos p1=-1)
Copy KV cache from one sequence to another.
llama_pos pos_max(llama_context *ctx, llama_seq_id seq)
Get maximum position in KV cache sequence.
float * get(llama_context *ctx, int32_t index=-1)
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Compute model-level surprisal for picked token.
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
void free_chain(llama_sampler *chain)
Free a sampler chain.
bool is_eog(const llama_vocab *vocab, llama_token token)
Check if token is end-of-generation marker.
Token Sampling Operations.
bool in_use
True when slot is allocated to an active branch.
const llama_model * model
Llama model (not owned, must outlive branch)
bool has_logits
True only after force_snapshot_logits(), prefill(), or step()
std::function< void(llama_token_data_array &)> steer_fn
Dynamic logit callback, NOT cloned on fork.
int n_batch
Batch size for decode operations.
llama_seq_id seq_id
KV cache sequence identifier (NO_LEASE when inactive)
llama_pos position
Current decode position in the sequence.
std::vector< llama_logit_bias > logit_bias
Static token biases, cloned on fork.
std::vector< llama_token_data > last_candidates
Filtered candidates from last sample()
std::vector< float > logits_snapshot
Captured logit distribution (n_vocab floats)
std::vector< llama_token_data > candidates_buffer
Reusable scratch buffer for sampling (avoids O(n_vocab) allocs per sample call).
boundaries::BoundaryTracker * boundary_tracker
Token boundary detector (owned, optional)
SamplerChainHandle sampler_chain
Handle into BranchStore's sampler registry.
llama_context * ctx
Llama context (not owned, must outlive branch)
BranchHandle parent
Parent branch (INVALID_HANDLE if root)
int n_vocab
Vocabulary size (cached for buffer pre-allocation)
GrammarHandle grammar
Handle into BranchStore's grammar registry.
std::vector< BranchHandle > children
Child branches forked from this one.
llama_token last_token
Last token returned by sample()
MetricsHandle metrics
Handle into BranchStore's metrics registry.
llama_pos fork_head
Parent's position at fork time (0 for root branches)
CachedSamplingParams cached_params
Params used to create current chain (for memoization)
uint16_t generation
Slot generation counter (for ABA prevention)
bool operator==(const CachedSamplingParams &) const =default
std::span< const llama_token > tokens
GrammarEntry & operator=(GrammarEntry &&o) noexcept
uint32_t remaining
n_ctx - cells_used (clamped to 0)
uint32_t n_ctx
Total KV capacity.
uint32_t cells_used
Cells allocated since last reset.
SamplerChainEntry()=default
SamplerChainEntry & operator=(SamplerChainEntry &&o) noexcept
bool has_dist
True if chain ends with dist (temp > 0), false if greedy.
llama_context * ctx
Context for KV operations (nullptr after drain)