65#include <llama/llama.h>
135 return static_cast<uint16_t
>(h >>
GEN_SHIFT);
145 return (
static_cast<uint32_t
>(generation) <<
GEN_SHIFT) | index;
173 :
chain(o.chain),
has_dist(o.has_dist) { o.chain =
nullptr; }
200 sampler = o.sampler; o.sampler =
nullptr;
236template <SamplingParamsLike P>
238 using ::lloyal::detail::as_value;
240 as_value(p.temperature, 0.8f),
241 as_value(p.top_k,
static_cast<int32_t
>(40)),
242 as_value(p.top_p, 0.95f),
243 as_value(p.typical_p, 1.0f),
244 as_value(p.min_p, 0.05f),
245 as_value(p.penalty_repeat, 1.0f),
246 as_value(p.penalty_freq, 0.0f),
247 as_value(p.penalty_present, 0.0f),
248 as_value(p.penalty_last_n,
static_cast<int32_t
>(64)),
249 as_value(p.seed,
static_cast<uint32_t
>(0)),
276 llama_context*
ctx =
nullptr;
291 std::function<void(llama_token_data_array&)>
steer_fn;
400 if (initial_capacity < 2) {
401 initial_capacity = 2;
403 slots_.resize(initial_capacity);
405 slots_[0].in_use =
true;
406 slots_[0].generation = 0xFFFF;
410 for (
size_t i = initial_capacity; i-- > 1; ) {
411 freelist_.push_back(
static_cast<uint16_t
>(i));
418 if (tenancy_.
ctx !=
nullptr) {
421 for (
size_t i = 1; i < slots_.size(); ++i) {
422 if (slots_[i].in_use) {
423 free_branch_resources(slots_[i]);
452 return {handle, seq};
472 c.erase(std::remove(c.begin(), c.end(), handle), c.end());
478 cells_used_ = (unique <= cells_used_) ? cells_used_ - unique : 0;
483 free_branch_resources(*st);
488 if (freelist_.size() == slots_.size() - 1) {
513 if (tenancy_.
ctx ==
nullptr)
return;
515 for (
size_t i = 1; i < slots_.size(); ++i) {
516 if (slots_[i].in_use) {
517 free_branch_resources(slots_[i]);
518 reset_slot(slots_[i]);
521 tenancy_.
ctx =
nullptr;
536 if (!w)
throw std::runtime_error(
"retainOnly: invalid winner handle");
537 if (w->
seq_id ==
NO_LEASE)
throw std::runtime_error(
"retainOnly: winner has no lease");
540 std::vector<BranchHandle> losers;
541 for (
size_t i = 1; i < slots_.size(); ++i) {
542 if (!slots_[i].in_use)
continue;
544 if (h == winner)
continue;
547 for (
auto h : losers)
548 release_slot_only(h);
552 cells_used_ =
static_cast<uint32_t
>(w->
position);
571 if (!tenancy_.
ctx)
return {0, 0, 0};
572 uint32_t n_ctx =
static_cast<uint32_t
>(llama_n_ctx(tenancy_.
ctx));
573 uint32_t remaining = (n_ctx > cells_used_) ? n_ctx - cells_used_ : 0;
574 return { n_ctx, cells_used_, remaining };
606 static const std::vector<BranchHandle> empty;
618 return st ? st->
children.empty() :
true;
647 if (index == 0)
return nullptr;
649 if (index >= slots_.size())
return nullptr;
671 template <SamplingParamsLike P>
677 entry.
has_dist = (temperature > 0.0f);
678 sampler_chains_.emplace(h, std::move(entry));
688 if (h == 0)
return 0;
689 auto it = sampler_chains_.find(h);
690 if (it == sampler_chains_.end())
return 0;
694 entry.
has_dist = it->second.has_dist;
695 sampler_chains_.emplace(nh, std::move(entry));
704 if (h != 0) sampler_chains_.erase(h);
713 if (h == 0)
return nullptr;
714 auto it = sampler_chains_.find(h);
715 return it != sampler_chains_.end() ? it->second.chain :
nullptr;
724 if (h == 0)
return false;
725 auto it = sampler_chains_.find(h);
726 return it != sampler_chains_.end() ? it->second.has_dist :
false;
739 const char* grammar_str,
740 const char* root =
"root") {
744 grammars_.emplace(h, std::move(entry));
758 const llama_model* model,
759 const char* grammar_str,
760 const std::vector<std::string>& trigger_patterns,
761 const std::vector<llama_token>& trigger_tokens,
762 const char* root =
"root") {
766 model, grammar_str, trigger_patterns, trigger_tokens, root);
768 grammars_.emplace(h, std::move(entry));
778 if (h == 0)
return 0;
779 auto it = grammars_.find(h);
780 if (it == grammars_.end())
return 0;
784 grammars_.emplace(nh, std::move(entry));
793 if (h != 0) grammars_.erase(h);
802 if (h == 0)
return nullptr;
803 auto it = grammars_.find(h);
804 return it != grammars_.end() ? it->second.sampler :
nullptr;
825 if (h == 0)
return 0;
826 auto it = metrics_registry_.find(h);
827 if (it == metrics_registry_.end())
return 0;
829 metrics_registry_[nh] = it->second;
838 if (h != 0) metrics_registry_.erase(h);
848 auto it = metrics_registry_.find(h);
849 if (it == metrics_registry_.end())
return;
850 if (!std::isfinite(surprisal))
return;
851 it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
852 it->second.model.count++;
862 auto it = metrics_registry_.find(h);
863 if (it == metrics_registry_.end())
return;
864 if (!std::isfinite(surprisal))
return;
865 it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
866 it->second.sampling.count++;
875 if (h == 0)
return std::numeric_limits<float>::infinity();
876 auto it = metrics_registry_.find(h);
877 if (it == metrics_registry_.end() || it->second.model.count == 0)
878 return std::numeric_limits<float>::infinity();
879 return std::exp(it->second.model.nll_sum_nats /
880 static_cast<float>(it->second.model.count));
889 if (h == 0)
return std::numeric_limits<float>::infinity();
890 auto it = metrics_registry_.find(h);
891 if (it == metrics_registry_.end() || it->second.sampling.count == 0)
892 return std::numeric_limits<float>::infinity();
893 return std::exp(it->second.sampling.nll_sum_nats /
894 static_cast<float>(it->second.sampling.count));
922 if (items.empty())
return;
924 const int32_t n =
static_cast<int32_t
>(items.size());
927 std::vector<BranchState*> states(n);
928 for (int32_t i = 0; i < n; ++i) {
929 states[i] =
get(items[i].handle);
931 throw std::runtime_error(
"BranchStore::decode_each - invalid handle at index " + std::to_string(i));
933 if (i > 0 && states[i]->ctx != states[0]->ctx) {
934 throw std::runtime_error(
"BranchStore::decode_each - all branches must share the same context");
939 std::vector<decode::EachItem> decode_items(n);
940 for (int32_t i = 0; i < n; ++i) {
941 decode_items[i].token = items[i].token;
942 decode_items[i].pos = states[i]->position;
943 decode_items[i].seq_id = states[i]->seq_id;
944 decode_items[i].output_logits =
true;
948 if (
decode::each(states[0]->ctx, decode_items.data(), n, scratch_) != 0) {
949 throw std::runtime_error(
"BranchStore::decode_each - llama_decode failed");
953 llama_context* ctx = states[0]->ctx;
954 for (int32_t i = 0; i < n; ++i) {
956 if (states[i]->n_vocab <= 0) {
957 throw std::runtime_error(
"BranchStore::decode_each - invalid vocab size at index " + std::to_string(i));
959 assert(states[i]->logits_snapshot.size() >=
static_cast<size_t>(states[i]->n_vocab));
960 std::memcpy(states[i]->logits_snapshot.data(), raw_logits,
961 states[i]->n_vocab *
sizeof(
float));
962 states[i]->has_logits =
true;
963 states[i]->position += 1;
965 cells_used_ +=
static_cast<uint32_t
>(items.size());
994 if (items.empty())
return;
996 const int32_t n =
static_cast<int32_t
>(items.size());
999 std::vector<BranchState*> states(n);
1000 for (int32_t i = 0; i < n; ++i) {
1001 states[i] =
get(items[i].handle);
1003 throw std::runtime_error(
"BranchStore::decode_scatter - invalid handle at index " + std::to_string(i));
1005 if (i > 0 && states[i]->ctx != states[0]->ctx) {
1006 throw std::runtime_error(
"BranchStore::decode_scatter - all branches must share the same context");
1010 llama_context* ctx = states[0]->ctx;
1011 const int32_t batch_limit =
static_cast<int32_t
>(llama_n_batch(ctx));
1014 std::vector<std::span<const llama_token>> spans(n);
1015 for (int32_t i = 0; i < n; ++i) {
1016 spans[i] = items[i].tokens;
1021 for (
const auto& chunk : chunks) {
1022 if (chunk.oversized) {
1023 int32_t idx = chunk.indices[0];
1024 int32_t tc =
static_cast<int32_t
>(items[idx].tokens.size());
1027 states[idx]->position, batch_limit,
1028 states[idx]->seq_id) != 0) {
1029 throw std::runtime_error(
"BranchStore::decode_scatter - decode::many failed for oversized item " + std::to_string(idx));
1033 assert(states[idx]->logits_snapshot.size() >=
static_cast<size_t>(states[idx]->n_vocab));
1034 std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
1035 states[idx]->n_vocab *
sizeof(
float));
1036 states[idx]->has_logits =
true;
1037 states[idx]->position += tc;
1042 std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
1043 for (
size_t k = 0; k < chunk.indices.size(); ++k) {
1044 int32_t idx = chunk.indices[k];
1045 scatter_items[k].tokens = items[idx].tokens;
1046 scatter_items[k].start_pos = states[idx]->position;
1047 scatter_items[k].seq_id = states[idx]->seq_id;
1048 scatter_items[k].output_logits =
true;
1052 static_cast<int32_t
>(scatter_items.size()),
1054 throw std::runtime_error(
"BranchStore::decode_scatter - decode::scatter failed");
1059 for (
size_t k = 0; k < scatter_items.size(); ++k) {
1060 int32_t idx = chunk.indices[k];
1061 int32_t item_n =
static_cast<int32_t
>(scatter_items[k].tokens.size());
1063 const float* raw_logits =
logits::get(ctx, cursor + item_n - 1);
1064 assert(states[idx]->logits_snapshot.size() >=
static_cast<size_t>(states[idx]->n_vocab));
1065 std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
1066 states[idx]->n_vocab *
sizeof(
float));
1067 states[idx]->has_logits =
true;
1068 states[idx]->position +=
static_cast<int32_t
>(items[idx].tokens.size());
1075 for (int32_t i = 0; i < n; ++i) {
1076 cells_used_ +=
static_cast<uint32_t
>(items[i].tokens.size());
1102 void reset_slot(BranchState& slot) {
1103 slot.in_use =
false;
1104 slot.generation =
static_cast<uint16_t
>(slot.generation + 1);
1106 slot.model =
nullptr;
1110 slot.sampler_chain = 0;
1113 slot.cached_params = CachedSamplingParams{};
1114 slot.last_token = -1;
1115 slot.last_candidates.clear();
1116 slot.logits_snapshot.clear();
1117 slot.has_logits =
false;
1118 slot.logit_bias.clear();
1119 slot.steer_fn =
nullptr;
1120 slot.candidates_buffer.clear();
1124 slot.children.clear();
1129 if (freelist_.empty()) {
1130 size_t old_size = slots_.size();
1131 size_t new_size = old_size * 2;
1135 if (old_size >= new_size) {
1139 slots_.resize(new_size);
1140 for (
size_t i = new_size; i-- > old_size; ) {
1141 freelist_.push_back(
static_cast<uint16_t
>(i));
1144 uint16_t index = freelist_.back();
1145 freelist_.pop_back();
1146 BranchState& slot = slots_[index];
1155 BranchState* st =
get(handle);
1159 BranchState* p =
get(st->parent);
1161 auto& c = p->children;
1162 c.erase(std::remove(c.begin(), c.end(), handle), c.end());
1165 free_branch_resources(*st);
1172 std::deque<BranchState> slots_;
1173 std::vector<uint16_t> freelist_;
1178 kv::tenancy::State tenancy_;
1183 uint32_t cells_used_ = 0;
1187 decode::Scratch scratch_;
1191 std::unordered_map<SamplerChainHandle, SamplerChainEntry> sampler_chains_;
1194 std::unordered_map<GrammarHandle, GrammarEntry> grammars_;
1197 std::unordered_map<MetricsHandle, metrics::BranchMetricsState> metrics_registry_;
1224template <SamplingParamsLike P>
1227 const llama_model* model,
1229 llama_pos start_pos,
1232 const char* grammar_str =
nullptr,
1234 if (!ctx || !model) {
1239 auto [handle, seq_id] = s.
allocate();
1251 state->
model = model;
1256 const llama_vocab* vocab = llama_model_get_vocab(model);
1257 state->
n_vocab = llama_vocab_n_tokens(vocab);
1265 if (grammar_str && grammar_str[0] !=
'\0') {
1272 LLOYAL_LOG_DEBUG(
"[branch::create] Created branch handle=%u seq=%d pos=%d",
1273 handle, seq_id, start_pos);
1307 auto [new_handle, new_seq_id] = s.
allocate();
1321 dst->
seq_id = new_seq_id;
1328 assert(
kv::pos_max(src->
ctx, new_seq_id) < 0 &&
"tenancy: acquired seq must be clean");
1337 src->
children.push_back(new_handle);
1369 LLOYAL_LOG_DEBUG(
"[branch::fork] Forked handle=%u -> handle=%u seq=%d->%d",
1370 source, new_handle, src->
seq_id, new_seq_id);
1397 const llama_logit_bias* biases,
1404 throw std::runtime_error(
"set_logit_bias: invalid branch handle");
1408 state->
logit_bias.assign(biases, biases + n_biases);
1428 throw std::runtime_error(
"clear_logit_bias: invalid branch handle");
1433 LLOYAL_LOG_DEBUG(
"[branch::clear_logit_bias] Cleared biases on handle=%u", handle);
1464 std::function<
void(llama_token_data_array&)> steer_fn,
1470 throw std::runtime_error(
"set_steer: invalid branch handle");
1473 state->
steer_fn = std::move(steer_fn);
1475 LLOYAL_LOG_DEBUG(
"[branch::set_steer] Set steer callback on handle=%u", handle);
1492 throw std::runtime_error(
"clear_steer: invalid branch handle");
1497 LLOYAL_LOG_DEBUG(
"[branch::clear_steer] Cleared steer callback on handle=%u", handle);
1515template <SamplingParamsLike P>
1519 throw std::runtime_error(
"set_sampler_params: invalid branch handle");
1536 LLOYAL_LOG_DEBUG(
"[branch::set_sampler_params] Rebuilt chain on handle=%u temp=%.3f",
1554 const llama_model* model,
1555 const char* grammar_str,
1559 throw std::runtime_error(
"set_grammar: invalid branch handle");
1569 if (grammar_str && grammar_str[0] !=
'\0') {
1574 state->
grammar != 0 ?
"Set" :
"Cleared", handle);
1593 const llama_model* model,
1594 const char* grammar_str,
1595 const std::vector<std::string>& trigger_patterns,
1596 const std::vector<llama_token>& trigger_tokens,
1600 throw std::runtime_error(
"set_grammar_lazy: invalid branch handle");
1608 if (grammar_str && grammar_str[0] !=
'\0') {
1610 model, grammar_str, trigger_patterns, trigger_tokens);
1614 state->
grammar != 0 ?
"Set" :
"Cleared", handle);
1631 throw std::runtime_error(
"prune: RESTRICT — branch has children. Use pruneSubtree() for CASCADE.");
1645 std::vector<BranchHandle> stack{h}, post_order;
1646 while (!stack.empty()) {
1648 post_order.push_back(cur);
1650 if (st)
for (
auto child : st->
children) stack.push_back(child);
1652 for (
auto it = post_order.rbegin(); it != post_order.rend(); ++it)
1680 throw std::runtime_error(
"force_snapshot_logits: invalid branch handle");
1687 throw std::runtime_error(
"force_snapshot_logits: invalid vocab size");
1691 state->
n_vocab *
sizeof(
float));
1713 const llama_token* tokens,
1720 throw std::runtime_error(
"prefill: invalid branch handle");
1726 throw std::runtime_error(
"prefill: llama_decode failed");
1729 state->
position +=
static_cast<llama_pos
>(n_tokens);
1736 throw std::runtime_error(
"prefill: invalid vocab size");
1740 state->
n_vocab *
sizeof(
float));
1764 throw std::runtime_error(
"step: invalid branch handle");
1768 throw std::runtime_error(
"step: llama_decode failed");
1777 throw std::runtime_error(
"step: invalid vocab size");
1781 state->
n_vocab *
sizeof(
float));
1826 if (!state || !chain) {
1832 LLOYAL_LOG_DEBUG(
"[branch::sample] No logits captured - call prefill()/step() first");
1837 for (
int i = 0; i < state->
n_vocab; i++) {
1839 static_cast<llama_token
>(i),
1844 llama_token_data_array cur_p = {
1846 static_cast<size_t>(state->
n_vocab),
1860 if (bias.token >= 0 && bias.token < state->
n_vocab) {
1861 cur_p.data[bias.token].logit += bias.bias;
1870 }
catch (
const std::exception& e) {
1879 if (cur_p.selected == -1) {
1883 llama_token token = cur_p.data[cur_p.selected].id;
1891 for (
size_t i = 0; i < cur_p.size; i++) {
1946 std::vector<float> candidate_logits;
1947 std::vector<int32_t> candidate_ids;
1952 candidate_logits.push_back(cand.logit);
1953 candidate_ids.push_back(cand.id);
1958 candidate_logits.data(),
1959 candidate_ids.data(),
1960 static_cast<int>(candidate_logits.size()),
1990 if (!state || !gram)
return;
1993 std::vector<llama_token_data>* candidates_ptr;
1994 std::vector<llama_token_data> temp_buffer;
1999 temp_buffer.resize(n_vocab);
2000 candidates_ptr = &temp_buffer;
2003 auto& candidates = *candidates_ptr;
2004 for (
int i = 0; i < n_vocab; i++) {
2005 candidates[i] = llama_token_data{
2006 static_cast<llama_token
>(i), logits[i], 0.0f};
2009 llama_token_data_array cur_p = {
2011 static_cast<size_t>(n_vocab),
2018 for (
int i = 0; i < n_vocab; i++) {
2019 logits[i] = candidates[i].logit;
2048 for (
int i = 0; i < state->
n_vocab; i++) {
2050 static_cast<llama_token
>(i),
2055 llama_token_data_array cur_p = {
2057 static_cast<size_t>(state->
n_vocab),
2069 std::vector<std::pair<llama_token, float>> legal_priors;
2070 float max_logit = -std::numeric_limits<float>::infinity();
2072 for (
size_t i = 0; i < cur_p.size; i++) {
2073 if (std::isfinite(cur_p.data[i].logit)) {
2074 legal_priors.emplace_back(cur_p.data[i].id, cur_p.data[i].logit);
2075 if (cur_p.data[i].logit > max_logit) {
2076 max_logit = cur_p.data[i].logit;
2081 if (legal_priors.empty()) {
2086 float sum_exp = 0.0f;
2087 for (
auto& [token, logit] : legal_priors) {
2088 float exp_val = std::exp(logit - max_logit);
2094 for (
auto& [token, prob] : legal_priors) {
2098 return legal_priors;
2121 return -std::numeric_limits<float>::infinity();
2125 for (
int i = 0; i < state->
n_vocab; i++) {
2127 static_cast<llama_token
>(i),
2132 llama_token_data_array cur_p = {
2134 static_cast<size_t>(state->
n_vocab),
2145 float max_logit = -std::numeric_limits<float>::infinity();
2146 for (
size_t i = 0; i < cur_p.size; i++) {
2147 if (std::isfinite(cur_p.data[i].logit) && cur_p.data[i].logit > max_logit) {
2148 max_logit = cur_p.data[i].logit;
2152 if (!std::isfinite(max_logit)) {
2153 return -std::numeric_limits<float>::infinity();
2156 float sum_exp = 0.0f;
2157 for (
size_t i = 0; i < cur_p.size; i++) {
2158 if (std::isfinite(cur_p.data[i].logit)) {
2159 sum_exp += std::exp(cur_p.data[i].logit - max_logit);
2163 return max_logit + std::log(sum_exp);
2184 if (!state || token < 0 || token >= state->
n_vocab) {
2195 llama_token_data single_candidate = {
2201 llama_token_data_array cur_p = {
2211 return std::isfinite(single_candidate.logit);
2241 return std::exp(logit - logsumexp);
2282 return state ? state->
position : -1;
2310 if (!state || state->
metrics == 0) {
2311 return std::numeric_limits<float>::infinity();
2330 if (!state || state->
metrics == 0) {
2331 return std::numeric_limits<float>::infinity();
2355 std::vector<float> candidate_logits;
2356 std::vector<int32_t> candidate_ids;
2361 candidate_logits.push_back(cand.logit);
2362 candidate_ids.push_back(cand.id);
2367 candidate_logits.data(),
2368 candidate_ids.data(),
2369 static_cast<int>(candidate_logits.size()),
2374 return std::exp(-surprisal);
2386 return state ? state->
n_vocab : 0;
2414 : store_(store), handle_(
handle) {}
2424 : store_(other.store_), handle_(other.handle_) {
2429 if (
this != &other) {
2433 store_ = other.store_;
2434 handle_ = other.handle_;
2444 template <SamplingParamsLike P>
2447 const llama_model* model,
2449 llama_pos start_pos,
2452 const char* grammar_str =
nullptr,
2455 return Branch(&store, h);
2461 return Branch(store_, h);
2485 void prefill(
const llama_token* tokens,
size_t n) {
2522 template <SamplingParamsLike P>
2555 static const std::vector<BranchHandle> empty;
2556 return store_ ? store_->
children(handle_) : empty;
Stub BoundaryTracker - does nothing.
virtual std::unique_ptr< BoundaryTracker > clone() const
Handle table and batched decode orchestrator for branch management.
GrammarHandle create_grammar_lazy(const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const char *root="root")
Create a lazy grammar (unconstrained until trigger fires)
float get_sampling_ppl(MetricsHandle h) const
Get sampling-level perplexity from a metrics tracker.
bool isActive(BranchHandle h) const
Test whether a branch holds a KV lease.
void free_grammar(GrammarHandle h)
Free a grammar.
GrammarHandle create_grammar(const llama_model *model, const char *grammar_str, const char *root="root")
Create a grammar sampler and register it.
MetricsHandle clone_metrics(MetricsHandle h)
Clone a metrics tracker (for fork)
size_t available() const
Number of vacant seq_ids available for acquisition.
void release(BranchHandle handle)
Release a branch slot + evict its KV lease.
void drain()
Explicit teardown — evict all leases while context is alive.
float get_model_ppl(MetricsHandle h) const
Get model-level perplexity from a metrics tracker.
bool isLeaf(BranchHandle h) const
Test whether a branch is a leaf (no children)
void decode_each(std::span< const DecodeEachItem > items)
Decode one token per branch in a single GPU dispatch.
bool sampler_has_dist(SamplerChainHandle h) const
Check if a sampler chain ends with dist (stochastic) or greedy.
MetricsHandle create_metrics()
Create a metrics tracker and register it.
~BranchStore()
Destructor — frees CPU resources.
SamplerChainHandle clone_sampler(SamplerChainHandle h)
Clone a sampler chain (for fork)
llama_pos fork_head(BranchHandle h) const
Get a branch's fork head (parent position at fork time)
Allocation allocate()
Allocate a branch slot + KV lease atomically.
void add_cells_used(uint32_t n)
Increment cells_used counter (for standalone prefill/step outside BranchStore methods)
const BranchState * get(BranchHandle handle) const
Look up branch state by handle.
void free_sampler(SamplerChainHandle h)
Free a sampler chain.
void retainOnly(BranchHandle winner)
Keep only the winner — nuclear KV + CPU cleanup.
KvPressure kv_pressure() const
KV cache pressure snapshot — O(1), no tree walking.
SamplerChainHandle create_sampler(const P &params)
Create a sampler chain and register it.
llama_sampler * get_grammar_sampler(GrammarHandle h) const
Dereference a grammar handle (non-owning)
void add_sampling_surprisal(MetricsHandle h, float surprisal)
Add sampling-level surprisal to a metrics tracker.
void init_tenancy(llama_context *ctx)
Initialize KV tenancy after context creation.
GrammarHandle clone_grammar(GrammarHandle h)
Clone a grammar (for fork)
void decode_scatter(std::span< const DecodeScatterItem > items)
Decode variable token counts per branch with auto-chunking.
BranchStore(size_t initial_capacity=16)
Construct a branch store with initial slot capacity.
void free_metrics(MetricsHandle h)
Free a metrics tracker.
const std::vector< BranchHandle > & children(BranchHandle h) const
Get a branch's child handles.
BranchState * get(BranchHandle handle)
Look up branch state by handle.
BranchHandle parent(BranchHandle h) const
Get a branch's parent handle.
void add_model_surprisal(MetricsHandle h, float surprisal)
Add model-level surprisal to a metrics tracker.
llama_sampler * get_sampler_chain(SamplerChainHandle h) const
Dereference a sampler chain handle (non-owning)
~Branch()
Destructor — CASCADE prunes entire subtree.
void accept(llama_token token)
Accept a token — advance grammar, penalty window, and metrics.
Branch & operator=(const Branch &)=delete
int n_vocab() const
Vocabulary size.
void setGrammar(const char *grammar_str)
Replace grammar constraint (nullptr/empty to remove)
const std::vector< BranchHandle > & childHandles() const
Child branch handles (empty if leaf)
void setSamplerParams(const P &params)
Replace sampler chain with new parameters (memoized)
BranchHandle handle() const
Underlying opaque handle (for interop with free functions)
const float * logits() const
Get the branch's captured logits snapshot.
bool is_eog(llama_token token) const
Check if a token is end-of-generation for this branch's model.
BranchHandle parentHandle() const
Parent branch handle, or INVALID_HANDLE if root.
Branch(BranchStore *store, BranchHandle handle)
llama_pos position() const
Current decode position (token count)
llama_token sample()
Sample a token from captured logits.
void step(llama_token token)
Decode one token and capture logits (generation step)
void pruneSubtree()
CASCADE prune — removes entire subtree.
float perplexity() const
Model-level perplexity (from raw logits, pre-filter)
Branch(const Branch &)=delete
bool isActive() const
True if this branch holds a KV lease.
llama_pos forkHead() const
Parent's position at fork time (0 for root branches)
void force_snapshot_logits()
Force-copy shared logits buffer into this branch's snapshot.
bool isLeaf() const
True if this branch has no children.
void prefill(const llama_token *tokens, size_t n)
Decode multiple tokens and capture logits atomically (prompt injection)
bool valid() const
True if this Branch holds a valid handle.
void prune()
RESTRICT prune (throws if children exist)
Branch fork()
Fork: allocates slot + lease, records topology edge.
Branch & operator=(Branch &&other) noexcept
static Branch create(llama_context *ctx, const llama_model *model, BranchStore &store, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Factory: allocates slot + lease from store.
Branch(Branch &&other) noexcept
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Batch Decoding Operations.
Grammar-Constrained Sampling.
constexpr llama_seq_id NO_LEASE
Sentinel value indicating a branch has no KV residency.
Zero-copy logits access with clear lifetime semantics.
Distribution Metrics for Test-Time Alignment.
void prefill(BranchHandle handle, const llama_token *tokens, size_t n_tokens, BranchStore &s)
Decode multiple tokens and capture logits atomically (prompt prefill)
void set_logit_bias(BranchHandle handle, const llama_logit_bias *biases, size_t n_biases, BranchStore &s)
float get_perplexity(BranchHandle handle, BranchStore &s)
Get model-level perplexity (from raw logits)
void apply_grammar(BranchHandle handle, float *logits, int n_vocab, BranchStore &s)
Apply grammar constraints to an external logits buffer.
int32_t GrammarHandle
Handle to a grammar sampler in BranchStore's registry (0 = invalid/none)
uint32_t BranchHandle
Opaque handle to a branch slot.
void step(BranchHandle handle, llama_token token, BranchStore &s)
Decode a single token and capture logits (generation step)
constexpr int DEFAULT_N_BATCH
Default batch size for decode operations.
void set_grammar_lazy(BranchHandle handle, const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, BranchStore &s)
Set lazy grammar on a branch (unconstrained until trigger fires)
void set_steer(BranchHandle handle, std::function< void(llama_token_data_array &)> steer_fn, BranchStore &s)
constexpr uint32_t GEN_SHIFT
Bit shift for generation field.
void accept_token(BranchHandle handle, llama_token token, BranchStore &s)
Accept a sampled token, advancing grammar and sampler state.
void force_snapshot_logits(BranchHandle handle, BranchStore &s)
Force-copy the shared llama.cpp logits buffer into this branch's private snapshot.
llama_pos get_fork_head(BranchHandle handle, BranchStore &s)
Get the branch's fork head (parent position at fork time)
float get_sampling_perplexity(BranchHandle handle, BranchStore &s)
Get sampling-level perplexity (from filtered distribution)
float get_legal_logsumexp(BranchHandle handle, BranchStore &s)
Compute log-sum-exp over grammar-legal logits.
void clear_logit_bias(BranchHandle handle, BranchStore &s)
Clear all logit biases from a branch.
uint16_t handle_generation(BranchHandle h)
Extract generation counter from a branch handle.
const float * get_logits(BranchHandle handle, BranchStore &s)
Get the branch's captured logits snapshot.
void set_sampler_params(BranchHandle handle, const P &params, BranchStore &s)
Replace a branch's sampler chain with new parameters.
void prune(BranchHandle handle, BranchStore &s)
Prune a leaf branch (RESTRICT — throws if children exist)
BranchHandle create(llama_context *ctx, const llama_model *model, BranchStore &s, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Create a new branch with sampler chain, optional grammar, and metrics.
CachedSamplingParams snapshot_params(const P &p)
Snapshot sampling params for memoization comparison.
constexpr uint32_t INDEX_MASK
Mask for slot index field.
BranchHandle fork(BranchHandle source, BranchStore &s)
Fork a branch into a new independent sequence.
int get_n_vocab(BranchHandle handle, BranchStore &s)
Get the branch's vocabulary size.
llama_token sample(BranchHandle handle, BranchStore &s)
Sample a token from the branch's captured logits.
void pruneSubtree(BranchHandle h, BranchStore &s)
Prune a branch and all descendants (CASCADE — iterative post-order)
constexpr BranchHandle INVALID_HANDLE
Null handle sentinel.
float get_token_prior(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token, checking grammar legality first.
constexpr llama_seq_id NO_LEASE
Branch has no KV residency.
int32_t MetricsHandle
Handle to a metrics tracker in BranchStore's registry (0 = invalid/none)
llama_pos get_position(BranchHandle handle, BranchStore &s)
Get the branch's current decode position.
void set_grammar(BranchHandle handle, const llama_model *model, const char *grammar_str, BranchStore &s)
Replace a branch's grammar constraint.
void clear_steer(BranchHandle handle, BranchStore &s)
Clear the steer callback from a branch.
float get_token_prior_assume_legal(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token known to be grammar-legal.
BranchHandle make_handle(uint16_t index, uint16_t generation)
Construct a branch handle from index and generation.
uint16_t handle_index(BranchHandle h)
Extract slot index from a branch handle.
int32_t SamplerChainHandle
Handle to a sampler chain in BranchStore's registry (0 = invalid/none)
std::vector< std::pair< llama_token, float > > get_legal_priors(BranchHandle handle, BranchStore &s)
Get grammar-legal tokens with renormalized probabilities.
bool is_token_legal(BranchHandle handle, llama_token token, BranchStore &s)
Check if a token is legal under grammar constraints.
float get_last_sampling_prior(BranchHandle handle, BranchStore &s)
Get the last sampled token's prior from the filtered distribution.
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
void free_sampler(llama_sampler *smpl)
Free a grammar sampler.
llama_sampler * clone_sampler(llama_sampler *smpl)
Clone a grammar sampler (for fork/branching).
llama_sampler * init_sampler(const llama_model *model, const std::string &grammar_str, const std::string &root_rule="root")
Initialize a grammar sampler from GBNF grammar string.
llama_sampler * init_lazy_sampler(const llama_model *model, const std::string &grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const std::string &root_rule="root")
Initialize a lazy grammar sampler from GBNF grammar string.
void accept(llama_sampler *smpl, llama_token token)
Accept a token into grammar state.
void apply(llama_sampler *smpl, llama_token_data_array *cur_p)
Apply grammar constraint to candidates.
void evict_all(State &s)
Evict every leased seq_id.
llama_seq_id acquire(State &s)
Acquire a seq_id from the vacant pool.
size_t available(const State &s)
Number of vacant seq_ids available for acquisition.
void evict(State &s, llama_seq_id seq)
Evict a seq_id — strip all KV tags then release.
void retain(State &s, llama_seq_id keep)
Nuclear retain — keep one seq, rebuild vacancy from scratch.
State init(llama_context *ctx, llama_seq_id n_seq_max)
Initialize tenancy with all seq_ids vacant.
void release(State &s, llama_seq_id seq)
Release a seq_id back to vacant — bookkeeping only, no KV calls.
void seq_cp(llama_context *ctx, llama_seq_id src, llama_seq_id dst, llama_pos p0=0, llama_pos p1=-1)
Copy KV cache from one sequence to another.
llama_pos pos_max(llama_context *ctx, llama_seq_id seq)
Get maximum position in KV cache sequence.
float * get(llama_context *ctx, int32_t index=-1)
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Compute model-level surprisal for picked token.
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
void free_chain(llama_sampler *chain)
Free a sampler chain.
bool is_eog(const llama_vocab *vocab, llama_token token)
Check if token is end-of-generation marker.
Token Sampling Operations.
Consolidated mutable state for a single branch.
bool in_use
True when slot is allocated to an active branch.
const llama_model * model
Llama model (not owned, must outlive branch)
bool has_logits
True only after force_snapshot_logits(), prefill(), or step()
std::function< void(llama_token_data_array &)> steer_fn
Dynamic logit callback, NOT cloned on fork.
int n_batch
Batch size for decode operations.
llama_seq_id seq_id
KV cache sequence identifier (NO_LEASE when inactive)
llama_pos position
Current decode position in the sequence.
std::vector< llama_logit_bias > logit_bias
Static token biases, cloned on fork.
std::vector< llama_token_data > last_candidates
Filtered candidates from last sample()
std::vector< float > logits_snapshot
Captured logit distribution (n_vocab floats)
std::vector< llama_token_data > candidates_buffer
Reusable scratch buffer for sampling (avoids O(n_vocab) allocs per sample call).
boundaries::BoundaryTracker * boundary_tracker
Token boundary detector (owned, optional)
SamplerChainHandle sampler_chain
Handle into BranchStore's sampler registry.
llama_context * ctx
Llama context (not owned, must outlive branch)
BranchHandle parent
Parent branch (INVALID_HANDLE if root)
int n_vocab
Vocabulary size (cached for buffer pre-allocation)
GrammarHandle grammar
Handle into BranchStore's grammar registry.
std::vector< BranchHandle > children
Child branches forked from this one.
llama_token last_token
Last token returned by sample()
MetricsHandle metrics
Handle into BranchStore's metrics registry.
llama_pos fork_head
Parent's position at fork time (0 for root branches)
CachedSamplingParams cached_params
Params used to create current chain (for memoization)
uint16_t generation
Slot generation counter (for ABA prevention)
Result of allocate(): a slot handle + its leased seq_id.
Concrete sampling params snapshot for memoization.
bool operator==(const CachedSamplingParams &) const =default
Item for decode_each: one token per branch.
Item for decode_scatter: variable tokens per branch.
std::span< const llama_token > tokens
RAII entry for a grammar sampler in the registry.
GrammarEntry(GrammarEntry &&o) noexcept
GrammarEntry & operator=(const GrammarEntry &)=delete
GrammarEntry & operator=(GrammarEntry &&o) noexcept
GrammarEntry(const GrammarEntry &)=delete
Snapshot of KV cache pressure from BranchStore.
uint32_t remaining
n_ctx - cells_used (clamped to 0)
uint32_t n_ctx
Total KV capacity.
uint32_t cells_used
Cells allocated since last reset.
RAII entry for a sampler chain in the registry.
SamplerChainEntry()=default
SamplerChainEntry(const SamplerChainEntry &)=delete
SamplerChainEntry & operator=(SamplerChainEntry &&o) noexcept
SamplerChainEntry(SamplerChainEntry &&o) noexcept
bool has_dist
True if chain ends with dist (temp > 0), false if greedy.
SamplerChainEntry & operator=(const SamplerChainEntry &)=delete
llama_context * ctx
Context for KV operations (nullptr after drain)
Unified model + sampling perplexity tracker.