liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
branch.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
56#include "boundaries.hpp"
57#include "common.hpp"
58#include "decode.hpp"
59#include "grammar.hpp"
60#include "kv.hpp"
61#include "logits.hpp"
62#include "metrics.hpp"
63#include "sampler.hpp"
64
65#include <llama/llama.h>
66#include <algorithm> // std::remove
67#include <cassert> // assert
68#include <cmath> // std::exp, std::log, std::isinf, std::isfinite
69#include <cstdint>
70#include <cstring> // std::memcpy
71#include <ctime> // std::time
72#include <deque> // std::deque (pointer stability for BranchStore)
73#include <functional> // std::function
74#include <limits> // std::numeric_limits
75#include <mutex>
76#include <span> // std::span (C++20)
77#include <stdexcept> // std::runtime_error
78#include <string> // std::to_string
79#include <utility> // std::pair, std::exchange
80#include <vector>
81
82namespace lloyal::branch {
83
84// ===== HANDLE TYPE =====
85
94using BranchHandle = uint32_t;
95
97constexpr llama_seq_id NO_LEASE = kv::NO_LEASE;
98constexpr int DEFAULT_N_BATCH = 512;
99constexpr uint32_t GEN_SHIFT = 16;
100constexpr uint32_t INDEX_MASK = 0xFFFF;
101
115 uint32_t n_ctx;
116 uint32_t cells_used;
117 uint32_t remaining;
118};
119
125inline uint16_t handle_index(BranchHandle h) {
126 return static_cast<uint16_t>(h & INDEX_MASK);
127}
128
134inline uint16_t handle_generation(BranchHandle h) {
135 return static_cast<uint16_t>(h >> GEN_SHIFT);
136}
137
144inline BranchHandle make_handle(uint16_t index, uint16_t generation) {
145 return (static_cast<uint32_t>(generation) << GEN_SHIFT) | index;
146}
147
148// ===== REGISTRY HANDLE TYPES =====
149
151using SamplerChainHandle = int32_t;
152
154using GrammarHandle = int32_t;
155
157using MetricsHandle = int32_t;
158
166 llama_sampler* chain = nullptr;
167 bool has_dist = false;
168
169 SamplerChainEntry() = default;
171
173 : chain(o.chain), has_dist(o.has_dist) { o.chain = nullptr; }
175 if (this != &o) {
177 chain = o.chain; has_dist = o.has_dist; o.chain = nullptr;
178 }
179 return *this;
180 }
183};
184
191 llama_sampler* sampler = nullptr;
192
193 GrammarEntry() = default;
195
196 GrammarEntry(GrammarEntry&& o) noexcept : sampler(o.sampler) { o.sampler = nullptr; }
198 if (this != &o) {
200 sampler = o.sampler; o.sampler = nullptr;
201 }
202 return *this;
203 }
204 GrammarEntry(const GrammarEntry&) = delete;
206};
207
216 float temperature = 0.8f;
217 int32_t top_k = 40;
218 float top_p = 0.95f;
219 float typical_p = 1.0f;
220 float min_p = 0.05f;
221 float penalty_repeat = 1.0f;
222 float penalty_freq = 0.0f;
223 float penalty_present = 0.0f;
224 int32_t penalty_last_n = 64;
225 uint32_t seed = 0;
226 bool operator==(const CachedSamplingParams&) const = default;
227};
228
236template <SamplingParamsLike P>
238 using ::lloyal::detail::as_value;
240 as_value(p.temperature, 0.8f),
241 as_value(p.top_k, static_cast<int32_t>(40)),
242 as_value(p.top_p, 0.95f),
243 as_value(p.typical_p, 1.0f),
244 as_value(p.min_p, 0.05f),
245 as_value(p.penalty_repeat, 1.0f),
246 as_value(p.penalty_freq, 0.0f),
247 as_value(p.penalty_present, 0.0f),
248 as_value(p.penalty_last_n, static_cast<int32_t>(64)),
249 as_value(p.seed, static_cast<uint32_t>(0)),
250 };
251}
252
253// ===== BRANCH STATE =====
254
276 llama_context* ctx = nullptr;
277 const llama_model* model = nullptr;
278
279 llama_seq_id seq_id = NO_LEASE;
280 llama_pos position = 0;
281 llama_pos fork_head = 0;
282
285
287
289
290 std::vector<llama_logit_bias> logit_bias;
291 std::function<void(llama_token_data_array&)> steer_fn;
292
294
295 llama_token last_token = -1;
296 std::vector<llama_token_data> last_candidates;
297
298 std::vector<float> logits_snapshot;
299 bool has_logits = false;
300
309 std::vector<llama_token_data> candidates_buffer;
310
312 int n_vocab = 0;
313
314 uint16_t generation = 0;
315 bool in_use = false;
316
317 // Topology — maintained by fork/prune/pruneSubtree
319 std::vector<BranchHandle> children;
320};
321
322// ===== BATCHED DECODE ITEM TYPES =====
323
334
350 std::span<const llama_token> tokens;
351};
352
353// ===== BRANCH STORE (HANDLE TABLE) =====
354
393public:
398 explicit BranchStore(size_t initial_capacity = 16) {
399 // Ensure minimum capacity of 2 (slot 0 reserved + at least 1 usable)
400 if (initial_capacity < 2) {
401 initial_capacity = 2;
402 }
403 slots_.resize(initial_capacity);
404 // Slot 0 is reserved (handle 0 = invalid)
405 slots_[0].in_use = true;
406 slots_[0].generation = 0xFFFF; // Never valid
407
408 // Initialize freelist with remaining slots
409 // NOTE: Use i-- > 1 pattern to avoid size_t underflow
410 for (size_t i = initial_capacity; i-- > 1; ) {
411 freelist_.push_back(static_cast<uint16_t>(i));
412 }
413 }
414
418 if (tenancy_.ctx != nullptr) {
419 LLOYAL_LOG_DEBUG("[BranchStore] WARNING: not drained before destruction");
420 }
421 for (size_t i = 1; i < slots_.size(); ++i) {
422 if (slots_[i].in_use) {
423 free_branch_resources(slots_[i]);
424 }
425 }
426 }
427
429 struct Allocation { BranchHandle handle; llama_seq_id seq_id; };
430
440 if (tenancy_.ctx == nullptr) return {INVALID_HANDLE, NO_LEASE}; // drained
441 llama_seq_id seq = kv::tenancy::acquire(tenancy_);
442 if (seq < 0) return {INVALID_HANDLE, NO_LEASE};
443 BranchHandle handle = allocate_slot();
444 if (handle == INVALID_HANDLE) {
445 // Seq was never used — bookkeeping-only return, no KV calls
446 kv::tenancy::release(tenancy_, seq);
447 return {INVALID_HANDLE, NO_LEASE};
448 }
449 // Stamp seq_id on slot so release() can evict properly
450 BranchState* st = get(handle);
451 st->seq_id = seq;
452 return {handle, seq};
453 }
454
  /// Release one branch: unlink it from its parent's child list, return its
  /// unique KV cells to the budget, evict its sequence lease, free its
  /// CPU-side registry resources, and recycle the slot. No-op for
  /// INVALID_HANDLE or stale handles (generation mismatch via get()).
  void release(BranchHandle handle) {
    if (handle == INVALID_HANDLE) return;
    BranchState* st = get(handle);
    if (!st) return;
    // Eager edge cleanup: remove from parent's children
    if (st->parent != INVALID_HANDLE) {
      BranchState* p = get(st->parent);
      if (p) {
        auto& c = p->children;
        c.erase(std::remove(c.begin(), c.end(), handle), c.end());
      }
    }
    // Subtract unique cells owned by this branch (above fork_head);
    // clamped so the pressure counter can never underflow
    if (st->position > st->fork_head) {
      uint32_t unique = static_cast<uint32_t>(st->position - st->fork_head);
      cells_used_ = (unique <= cells_used_) ? cells_used_ - unique : 0;
    }
    // Evict lease (KV strip + bookkeeping)
    if (st->seq_id != NO_LEASE)
      kv::tenancy::evict(tenancy_, st->seq_id);
    free_branch_resources(*st);
    reset_slot(*st); // bumps generation, invalidating outstanding handles
    freelist_.push_back(handle_index(handle));

    // All branches released → KV cache empty, reset pressure counter.
    // (Freelist then covers every slot except reserved slot 0.)
    if (freelist_.size() == slots_.size() - 1) {
      cells_used_ = 0;
    }
  }
492
493 // ===== TENANCY LIFECYCLE =====
494
499 void init_tenancy(llama_context* ctx) {
500 tenancy_ = kv::tenancy::init(ctx, llama_n_seq_max(ctx));
501 cells_used_ = 0;
502 }
503
  /// Tear down all KV state and per-branch CPU resources. Idempotent: a
  /// second call returns immediately (tenancy_.ctx == nullptr marks the
  /// drained state). After drain(), allocate() refuses new branches until
  /// init_tenancy() is called again.
  /// NOTE(review): reset slots are not pushed back onto freelist_ here, so
  /// their indices stay unreachable after a drain()+init_tenancy() cycle
  /// until the table grows — confirm this is intended.
  void drain() {
    if (tenancy_.ctx == nullptr) return; // idempotent
    kv::tenancy::evict_all(tenancy_);
    for (size_t i = 1; i < slots_.size(); ++i) {
      if (slots_[i].in_use) {
        free_branch_resources(slots_[i]);
        reset_slot(slots_[i]);
      }
    }
    tenancy_.ctx = nullptr; // marks as drained
    cells_used_ = 0;
  }
524
535 BranchState* w = get(winner);
536 if (!w) throw std::runtime_error("retainOnly: invalid winner handle");
537 if (w->seq_id == NO_LEASE) throw std::runtime_error("retainOnly: winner has no lease");
538 kv::tenancy::retain(tenancy_, w->seq_id); // nuclear KV pass
539 // Collect losers first — don't mutate while iterating
540 std::vector<BranchHandle> losers;
541 for (size_t i = 1; i < slots_.size(); ++i) {
542 if (!slots_[i].in_use) continue;
543 BranchHandle h = make_handle(static_cast<uint16_t>(i), slots_[i].generation);
544 if (h == winner) continue;
545 losers.push_back(h);
546 }
547 for (auto h : losers)
548 release_slot_only(h); // CPU only, KV already stripped
550 w->fork_head = 0;
551 w->children.clear();
552 cells_used_ = static_cast<uint32_t>(w->position);
553 }
554
555 // ===== TOPOLOGY QUERIES =====
556
561 size_t available() const { return kv::tenancy::available(tenancy_); }
562
571 if (!tenancy_.ctx) return {0, 0, 0};
572 uint32_t n_ctx = static_cast<uint32_t>(llama_n_ctx(tenancy_.ctx));
573 uint32_t remaining = (n_ctx > cells_used_) ? n_ctx - cells_used_ : 0;
574 return { n_ctx, cells_used_, remaining };
575 }
576
578 void add_cells_used(uint32_t n) { cells_used_ += n; }
579
586 const BranchState* st = get(h);
587 return st ? st->parent : INVALID_HANDLE;
588 }
589
595 llama_pos fork_head(BranchHandle h) const {
596 const BranchState* st = get(h);
597 return st ? st->fork_head : 0;
598 }
599
605 const std::vector<BranchHandle>& children(BranchHandle h) const {
606 static const std::vector<BranchHandle> empty;
607 const BranchState* st = get(h);
608 return st ? st->children : empty;
609 }
610
616 bool isLeaf(BranchHandle h) const {
617 const BranchState* st = get(h);
618 return st ? st->children.empty() : true;
619 }
620
626 bool isActive(BranchHandle h) const {
627 const BranchState* st = get(h);
628 return st ? (st->seq_id != NO_LEASE) : false;
629 }
630
641 if (handle == INVALID_HANDLE) return nullptr;
642
643 uint16_t index = handle_index(handle);
644 uint16_t gen = handle_generation(handle);
645
646 // Slot 0 is reserved and never valid for external use
647 if (index == 0) return nullptr;
648
649 if (index >= slots_.size()) return nullptr;
650
651 BranchState& slot = slots_[index];
652 if (!slot.in_use || slot.generation != gen) {
653 return nullptr;
654 }
655
656 return &slot;
657 }
658
660 const BranchState* get(BranchHandle handle) const {
661 return const_cast<BranchStore*>(this)->get(handle);
662 }
663
664 // ===== SAMPLER CHAIN REGISTRY =====
665
671 template <SamplingParamsLike P>
673 SamplerChainHandle h = next_sampler_handle_++;
674 SamplerChainEntry entry;
675 entry.chain = sampler::create_chain(params);
676 float temperature = ::lloyal::detail::as_value(params.temperature, 0.8f);
677 entry.has_dist = (temperature > 0.0f);
678 sampler_chains_.emplace(h, std::move(entry));
679 return h;
680 }
681
688 if (h == 0) return 0;
689 auto it = sampler_chains_.find(h);
690 if (it == sampler_chains_.end()) return 0;
691 SamplerChainHandle nh = next_sampler_handle_++;
692 SamplerChainEntry entry;
693 entry.chain = sampler::clone_chain(it->second.chain);
694 entry.has_dist = it->second.has_dist;
695 sampler_chains_.emplace(nh, std::move(entry));
696 return nh;
697 }
698
704 if (h != 0) sampler_chains_.erase(h);
705 }
706
712 llama_sampler* get_sampler_chain(SamplerChainHandle h) const {
713 if (h == 0) return nullptr;
714 auto it = sampler_chains_.find(h);
715 return it != sampler_chains_.end() ? it->second.chain : nullptr;
716 }
717
724 if (h == 0) return false;
725 auto it = sampler_chains_.find(h);
726 return it != sampler_chains_.end() ? it->second.has_dist : false;
727 }
728
729 // ===== GRAMMAR REGISTRY =====
730
738 GrammarHandle create_grammar(const llama_model* model,
739 const char* grammar_str,
740 const char* root = "root") {
741 GrammarHandle h = next_grammar_handle_++;
742 GrammarEntry entry;
743 entry.sampler = grammar::init_sampler(model, grammar_str, root);
744 grammars_.emplace(h, std::move(entry));
745 return h;
746 }
747
758 const llama_model* model,
759 const char* grammar_str,
760 const std::vector<std::string>& trigger_patterns,
761 const std::vector<llama_token>& trigger_tokens,
762 const char* root = "root") {
763 GrammarHandle h = next_grammar_handle_++;
764 GrammarEntry entry;
766 model, grammar_str, trigger_patterns, trigger_tokens, root);
767 if (!entry.sampler) return 0;
768 grammars_.emplace(h, std::move(entry));
769 return h;
770 }
771
778 if (h == 0) return 0;
779 auto it = grammars_.find(h);
780 if (it == grammars_.end()) return 0;
781 GrammarHandle nh = next_grammar_handle_++;
782 GrammarEntry entry;
783 entry.sampler = grammar::clone_sampler(it->second.sampler);
784 grammars_.emplace(nh, std::move(entry));
785 return nh;
786 }
787
793 if (h != 0) grammars_.erase(h);
794 }
795
801 llama_sampler* get_grammar_sampler(GrammarHandle h) const {
802 if (h == 0) return nullptr;
803 auto it = grammars_.find(h);
804 return it != grammars_.end() ? it->second.sampler : nullptr;
805 }
806
807 // ===== METRICS REGISTRY =====
808
814 MetricsHandle h = next_metrics_handle_++;
815 metrics_registry_[h] = metrics::BranchMetricsState{};
816 return h;
817 }
818
825 if (h == 0) return 0;
826 auto it = metrics_registry_.find(h);
827 if (it == metrics_registry_.end()) return 0;
828 MetricsHandle nh = next_metrics_handle_++;
829 metrics_registry_[nh] = it->second;
830 return nh;
831 }
832
838 if (h != 0) metrics_registry_.erase(h);
839 }
840
846 void add_model_surprisal(MetricsHandle h, float surprisal) {
847 if (h == 0) return;
848 auto it = metrics_registry_.find(h);
849 if (it == metrics_registry_.end()) return;
850 if (!std::isfinite(surprisal)) return;
851 it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
852 it->second.model.count++;
853 }
854
860 void add_sampling_surprisal(MetricsHandle h, float surprisal) {
861 if (h == 0) return;
862 auto it = metrics_registry_.find(h);
863 if (it == metrics_registry_.end()) return;
864 if (!std::isfinite(surprisal)) return;
865 it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
866 it->second.sampling.count++;
867 }
868
875 if (h == 0) return std::numeric_limits<float>::infinity();
876 auto it = metrics_registry_.find(h);
877 if (it == metrics_registry_.end() || it->second.model.count == 0)
878 return std::numeric_limits<float>::infinity();
879 return std::exp(it->second.model.nll_sum_nats /
880 static_cast<float>(it->second.model.count));
881 }
882
889 if (h == 0) return std::numeric_limits<float>::infinity();
890 auto it = metrics_registry_.find(h);
891 if (it == metrics_registry_.end() || it->second.sampling.count == 0)
892 return std::numeric_limits<float>::infinity();
893 return std::exp(it->second.sampling.nll_sum_nats /
894 static_cast<float>(it->second.sampling.count));
895 }
896
897 // ===== BATCHED DECODE =====
898
  /// Decode exactly one token per branch in a single batched dispatch.
  ///
  /// All items must refer to branches sharing one llama_context. On success
  /// each branch's logits snapshot is refreshed from its batch row, its
  /// position advances by one, and the global KV-cell counter grows by
  /// items.size().
  ///
  /// @param items one (handle, token) pair per branch
  /// @throws std::runtime_error on an invalid handle, mixed contexts,
  ///         decode failure, or a non-positive cached vocab size
  void decode_each(std::span<const DecodeEachItem> items) {
    if (items.empty()) return;

    const int32_t n = static_cast<int32_t>(items.size());

    // Validate handles and collect states
    std::vector<BranchState*> states(n);
    for (int32_t i = 0; i < n; ++i) {
      states[i] = get(items[i].handle);
      if (!states[i]) {
        throw std::runtime_error("BranchStore::decode_each - invalid handle at index " + std::to_string(i));
      }
      if (i > 0 && states[i]->ctx != states[0]->ctx) {
        throw std::runtime_error("BranchStore::decode_each - all branches must share the same context");
      }
    }

    // Build EachItem array from branch states; every item requests logits
    // so batch row i below corresponds to input item i.
    std::vector<decode::EachItem> decode_items(n);
    for (int32_t i = 0; i < n; ++i) {
      decode_items[i].token = items[i].token;
      decode_items[i].pos = states[i]->position;
      decode_items[i].seq_id = states[i]->seq_id;
      decode_items[i].output_logits = true;
    }

    // Single GPU dispatch
    if (decode::each(states[0]->ctx, decode_items.data(), n, scratch_) != 0) {
      throw std::runtime_error("BranchStore::decode_each - llama_decode failed");
    }

    // Capture logits and update positions
    llama_context* ctx = states[0]->ctx;
    for (int32_t i = 0; i < n; ++i) {
      const float* raw_logits = logits::get(ctx, i); // throws on null
      if (states[i]->n_vocab <= 0) {
        throw std::runtime_error("BranchStore::decode_each - invalid vocab size at index " + std::to_string(i));
      }
      assert(states[i]->logits_snapshot.size() >= static_cast<size_t>(states[i]->n_vocab));
      std::memcpy(states[i]->logits_snapshot.data(), raw_logits,
                  states[i]->n_vocab * sizeof(float));
      states[i]->has_logits = true;
      states[i]->position += 1;
    }
    // One new KV cell per branch
    cells_used_ += static_cast<uint32_t>(items.size());
  }
967
  /// Decode a variable-length token span per branch, bin-packing spans into
  /// batches bounded by llama_n_batch. Spans larger than one batch fall back
  /// to decode::many (internally chunked). After each dispatch the logits of
  /// the LAST token of every span are snapshotted into the owning branch and
  /// its position advances by the span length; cells_used_ grows by the total
  /// token count.
  ///
  /// @param items one (handle, token-span) pair per branch
  /// @throws std::runtime_error on an invalid handle, mixed contexts, or a
  ///         failed decode (either path)
  void decode_scatter(std::span<const DecodeScatterItem> items) {
    if (items.empty()) return;

    const int32_t n = static_cast<int32_t>(items.size());

    // Validate handles and collect states
    std::vector<BranchState*> states(n);
    for (int32_t i = 0; i < n; ++i) {
      states[i] = get(items[i].handle);
      if (!states[i]) {
        throw std::runtime_error("BranchStore::decode_scatter - invalid handle at index " + std::to_string(i));
      }
      if (i > 0 && states[i]->ctx != states[0]->ctx) {
        throw std::runtime_error("BranchStore::decode_scatter - all branches must share the same context");
      }
    }

    llama_context* ctx = states[0]->ctx;
    const int32_t batch_limit = static_cast<int32_t>(llama_n_batch(ctx));

    // Build flat token spans — bin_pack skips empties internally
    std::vector<std::span<const llama_token>> spans(n);
    for (int32_t i = 0; i < n; ++i) {
      spans[i] = items[i].tokens;
    }

    auto chunks = decode::bin_pack(spans.data(), n, batch_limit);

    for (const auto& chunk : chunks) {
      if (chunk.oversized) {
        // Span exceeds one batch: decode::many chunks it itself; an
        // oversized chunk holds exactly one item index.
        int32_t idx = chunk.indices[0];
        int32_t tc = static_cast<int32_t>(items[idx].tokens.size());

        if (decode::many(ctx, items[idx].tokens.data(), tc,
                         states[idx]->position, batch_limit,
                         states[idx]->seq_id) != 0) {
          throw std::runtime_error("BranchStore::decode_scatter - decode::many failed for oversized item " + std::to_string(idx));
        }

        // -1 = logits of the last decoded token
        const float* raw_logits = logits::get(ctx, -1);
        assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
        std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
                    states[idx]->n_vocab * sizeof(float));
        states[idx]->has_logits = true;
        states[idx]->position += tc;
        continue;
      }

      // Normal chunk — build ScatterItems and dispatch
      std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
      for (size_t k = 0; k < chunk.indices.size(); ++k) {
        int32_t idx = chunk.indices[k];
        scatter_items[k].tokens = items[idx].tokens;
        scatter_items[k].start_pos = states[idx]->position;
        scatter_items[k].seq_id = states[idx]->seq_id;
        scatter_items[k].output_logits = true;
      }

      if (decode::scatter(ctx, scatter_items.data(),
                          static_cast<int32_t>(scatter_items.size()),
                          scratch_) != 0) {
        throw std::runtime_error("BranchStore::decode_scatter - decode::scatter failed");
      }

      // Capture logits for each item in the chunk. Items are laid out
      // back-to-back in the batch, so `cursor` tracks each item's start row
      // and cursor + item_n - 1 is its last-token row.
      int32_t cursor = 0;
      for (size_t k = 0; k < scatter_items.size(); ++k) {
        int32_t idx = chunk.indices[k];
        int32_t item_n = static_cast<int32_t>(scatter_items[k].tokens.size());

        const float* raw_logits = logits::get(ctx, cursor + item_n - 1);
        assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
        std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
                    states[idx]->n_vocab * sizeof(float));
        states[idx]->has_logits = true;
        states[idx]->position += static_cast<int32_t>(items[idx].tokens.size());

        cursor += item_n;
      }
    }

    // Accumulate total tokens decoded across all items
    for (int32_t i = 0; i < n; ++i) {
      cells_used_ += static_cast<uint32_t>(items[i].tokens.size());
    }
  }
1079
1080private:
1082 void free_branch_resources(BranchState& slot) {
1083 if (slot.sampler_chain != 0) {
1085 slot.sampler_chain = 0;
1086 }
1087 if (slot.grammar != 0) {
1088 free_grammar(slot.grammar);
1089 slot.grammar = 0;
1090 }
1091 if (slot.boundary_tracker) {
1092 delete slot.boundary_tracker;
1093 slot.boundary_tracker = nullptr;
1094 }
1095 if (slot.metrics != 0) {
1096 free_metrics(slot.metrics);
1097 slot.metrics = 0;
1098 }
1099 }
1100
  /// Return a slot to its pristine state and bump its generation so any
  /// outstanding handles to this index go stale (get() rejects them).
  /// Does NOT free registry resources — every caller in this file
  /// (release, drain, release_slot_only) runs free_branch_resources()
  /// immediately before this, which also nulls boundary_tracker.
  void reset_slot(BranchState& slot) {
    slot.in_use = false;
    slot.generation = static_cast<uint16_t>(slot.generation + 1); // Prevent ABA
    slot.ctx = nullptr;
    slot.model = nullptr;
    slot.seq_id = NO_LEASE;
    slot.position = 0;
    slot.fork_head = 0;
    // Registry handles: zeroed only; freeing happened in free_branch_resources()
    slot.sampler_chain = 0;
    slot.grammar = 0;
    slot.metrics = 0;
    slot.cached_params = CachedSamplingParams{};
    slot.last_token = -1;
    slot.last_candidates.clear();
    slot.logits_snapshot.clear();
    slot.has_logits = false;
    slot.logit_bias.clear();
    slot.steer_fn = nullptr;
    slot.candidates_buffer.clear();
    slot.n_batch = DEFAULT_N_BATCH;
    slot.n_vocab = 0;
    // Topology links are severed; the parent/child edges were already
    // cleaned up by the releasing caller.
    slot.parent = INVALID_HANDLE;
    slot.children.clear();
  }
1126
1128 BranchHandle allocate_slot() {
1129 if (freelist_.empty()) {
1130 size_t old_size = slots_.size();
1131 size_t new_size = old_size * 2;
1132 if (new_size > INDEX_MASK) {
1133 new_size = INDEX_MASK + 1;
1134 }
1135 if (old_size >= new_size) {
1136 LLOYAL_LOG_DEBUG("[branch::allocate_slot] Store full, cannot allocate");
1137 return INVALID_HANDLE;
1138 }
1139 slots_.resize(new_size);
1140 for (size_t i = new_size; i-- > old_size; ) {
1141 freelist_.push_back(static_cast<uint16_t>(i));
1142 }
1143 }
1144 uint16_t index = freelist_.back();
1145 freelist_.pop_back();
1146 BranchState& slot = slots_[index];
1147 slot.in_use = true;
1148 return make_handle(index, slot.generation);
1149 }
1150
  /// CPU-only slot teardown: unlink from the parent's child list, free
  /// registry resources, recycle the index. Unlike release(), performs NO
  /// KV eviction and no cells_used_ accounting — used on retainOnly()
  /// losers after the KV cache was already stripped wholesale.
  void release_slot_only(BranchHandle handle) {
    if (handle == INVALID_HANDLE) return;
    BranchState* st = get(handle);
    if (!st) return;
    // Eager edge cleanup
    if (st->parent != INVALID_HANDLE) {
      BranchState* p = get(st->parent);
      if (p) {
        auto& c = p->children;
        c.erase(std::remove(c.begin(), c.end(), handle), c.end());
      }
    }
    free_branch_resources(*st);
    reset_slot(*st); // bumps generation, invalidating outstanding handles
    freelist_.push_back(handle_index(handle));
  }
1169
1172 std::deque<BranchState> slots_;
1173 std::vector<uint16_t> freelist_;
1174
1178 kv::tenancy::State tenancy_;
1179
1183 uint32_t cells_used_ = 0;
1184
1187 decode::Scratch scratch_;
1188
1189 // ===== Handle registries (instance-scoped, not global static) =====
1190
1191 std::unordered_map<SamplerChainHandle, SamplerChainEntry> sampler_chains_;
1192 SamplerChainHandle next_sampler_handle_ = 1;
1193
1194 std::unordered_map<GrammarHandle, GrammarEntry> grammars_;
1195 GrammarHandle next_grammar_handle_ = 1;
1196
1197 std::unordered_map<MetricsHandle, metrics::BranchMetricsState> metrics_registry_;
1198 MetricsHandle next_metrics_handle_ = 1;
1199};
1200
1201
1202// ===== BRANCH API =====
1203
1224template <SamplingParamsLike P>
1226 llama_context* ctx,
1227 const llama_model* model,
1228 BranchStore& s,
1229 llama_pos start_pos,
1230 const P& params,
1231 int n_batch = DEFAULT_N_BATCH,
1232 const char* grammar_str = nullptr,
1233 boundaries::BoundaryTracker* boundary_tracker = nullptr) {
1234 if (!ctx || !model) {
1235 LLOYAL_LOG_DEBUG("[branch::create] NULL ctx or model");
1236 return INVALID_HANDLE;
1237 }
1238
1239 auto [handle, seq_id] = s.allocate();
1240 if (handle == INVALID_HANDLE) {
1241 return INVALID_HANDLE;
1242 }
1243
1244 BranchState* state = s.get(handle);
1245 if (!state) {
1246 s.release(handle);
1247 return INVALID_HANDLE;
1248 }
1249
1250 state->ctx = ctx;
1251 state->model = model;
1252 // seq_id already stamped by allocate()
1253 state->position = start_pos;
1254 state->n_batch = n_batch;
1255
1256 const llama_vocab* vocab = llama_model_get_vocab(model);
1257 state->n_vocab = llama_vocab_n_tokens(vocab);
1258 state->logits_snapshot.resize(state->n_vocab);
1259 state->has_logits = false;
1260 state->candidates_buffer.resize(state->n_vocab);
1261
1262 state->sampler_chain = s.create_sampler(params);
1263 state->cached_params = snapshot_params(params);
1264
1265 if (grammar_str && grammar_str[0] != '\0') {
1266 state->grammar = s.create_grammar(model, grammar_str);
1267 }
1268
1269 state->boundary_tracker = boundary_tracker;
1270 state->metrics = s.create_metrics();
1271
1272 LLOYAL_LOG_DEBUG("[branch::create] Created branch handle=%u seq=%d pos=%d",
1273 handle, seq_id, start_pos);
1274
1275 return handle;
1276}
1277
1278
1301 BranchState* src = s.get(source);
1302 if (!src) {
1303 LLOYAL_LOG_DEBUG("[branch::fork] Invalid source handle");
1304 return INVALID_HANDLE;
1305 }
1306
1307 auto [new_handle, new_seq_id] = s.allocate();
1308 if (new_handle == INVALID_HANDLE) {
1309 return INVALID_HANDLE;
1310 }
1311
1312 BranchState* dst = s.get(new_handle);
1313 if (!dst) {
1314 s.release(new_handle);
1315 return INVALID_HANDLE;
1316 }
1317
1318 // Copy basic state
1319 dst->ctx = src->ctx;
1320 dst->model = src->model;
1321 dst->seq_id = new_seq_id;
1322 dst->position = src->position;
1323 dst->fork_head = src->position;
1324 dst->n_batch = src->n_batch;
1325 dst->n_vocab = src->n_vocab;
1326
1327#ifndef NDEBUG
1328 assert(kv::pos_max(src->ctx, new_seq_id) < 0 && "tenancy: acquired seq must be clean");
1329 assert(dst->parent == INVALID_HANDLE && dst->children.empty() && "fresh slot must have no topology");
1330#endif
1331
1332 // Fork KV cache
1333 kv::seq_cp(src->ctx, src->seq_id, new_seq_id);
1334
1335 // Record topology
1336 dst->parent = source;
1337 src->children.push_back(new_handle);
1338
1339 // Clone sampler chain
1340 if (src->sampler_chain != 0) {
1342 }
1343 dst->cached_params = src->cached_params;
1344
1345 if (src->grammar != 0) {
1346 dst->grammar = s.clone_grammar(src->grammar);
1347 }
1348
1349 if (src->boundary_tracker) {
1350 dst->boundary_tracker = src->boundary_tracker->clone().release();
1351 }
1352
1353 if (src->metrics != 0) {
1354 dst->metrics = s.clone_metrics(src->metrics);
1355 }
1356
1357 dst->last_token = src->last_token;
1358 dst->last_candidates = src->last_candidates;
1359
1360 // logits_snapshot copy is intentional: fork's contract is "sample different
1361 // tokens from the same logit distribution." Without the copy, the child can't
1362 // sample without a redundant decode. Cost: n_vocab * 4 bytes (~512KB at 128k vocab).
1363 dst->logits_snapshot = src->logits_snapshot;
1364 dst->has_logits = src->has_logits;
1365 dst->logit_bias = src->logit_bias;
1366
1367 dst->candidates_buffer.resize(dst->n_vocab);
1368
1369 LLOYAL_LOG_DEBUG("[branch::fork] Forked handle=%u -> handle=%u seq=%d->%d",
1370 source, new_handle, src->seq_id, new_seq_id);
1371
1372 return new_handle;
1373}
1374
1395inline void set_logit_bias(
1396 BranchHandle handle,
1397 const llama_logit_bias* biases,
1398 size_t n_biases,
1399 BranchStore& s) {
1400
1401
1402 BranchState* state = s.get(handle);
1403 if (!state) {
1404 throw std::runtime_error("set_logit_bias: invalid branch handle");
1405 }
1406
1407 // Replace existing biases with new set
1408 state->logit_bias.assign(biases, biases + n_biases);
1409
1410 LLOYAL_LOG_DEBUG("[branch::set_logit_bias] Set %zu biases on handle=%u",
1411 n_biases, handle);
1412}
1413
1422 BranchHandle handle,
1423 BranchStore& s) {
1424
1425
1426 BranchState* state = s.get(handle);
1427 if (!state) {
1428 throw std::runtime_error("clear_logit_bias: invalid branch handle");
1429 }
1430
1431 state->logit_bias.clear();
1432
1433 LLOYAL_LOG_DEBUG("[branch::clear_logit_bias] Cleared biases on handle=%u", handle);
1434}
1435
1462inline void set_steer(
1463 BranchHandle handle,
1464 std::function<void(llama_token_data_array&)> steer_fn,
1465 BranchStore& s) {
1466
1467
1468 BranchState* state = s.get(handle);
1469 if (!state) {
1470 throw std::runtime_error("set_steer: invalid branch handle");
1471 }
1472
1473 state->steer_fn = std::move(steer_fn);
1474
1475 LLOYAL_LOG_DEBUG("[branch::set_steer] Set steer callback on handle=%u", handle);
1476}
1477
1485inline void clear_steer(
1486 BranchHandle handle,
1487 BranchStore& s) {
1488
1489
1490 BranchState* state = s.get(handle);
1491 if (!state) {
1492 throw std::runtime_error("clear_steer: invalid branch handle");
1493 }
1494
1495 state->steer_fn = nullptr;
1496
1497 LLOYAL_LOG_DEBUG("[branch::clear_steer] Cleared steer callback on handle=%u", handle);
1498}
1499
1515template <SamplingParamsLike P>
1516inline void set_sampler_params(BranchHandle handle, const P& params, BranchStore& s) {
1517 BranchState* state = s.get(handle);
1518 if (!state) {
1519 throw std::runtime_error("set_sampler_params: invalid branch handle");
1520 }
1521
1522 CachedSamplingParams new_params = snapshot_params(params);
1523 if (new_params == state->cached_params && state->sampler_chain != 0) {
1524 return; // Memoized — no rebuild needed
1525 }
1526
1527 // Free old chain
1528 if (state->sampler_chain != 0) {
1529 s.free_sampler(state->sampler_chain);
1530 }
1531
1532 // Create new chain
1533 state->sampler_chain = s.create_sampler(params);
1534 state->cached_params = new_params;
1535
1536 LLOYAL_LOG_DEBUG("[branch::set_sampler_params] Rebuilt chain on handle=%u temp=%.3f",
1537 handle, new_params.temperature);
1538}
1539
1552inline void set_grammar(
1553 BranchHandle handle,
1554 const llama_model* model,
1555 const char* grammar_str,
1556 BranchStore& s) {
1557 BranchState* state = s.get(handle);
1558 if (!state) {
1559 throw std::runtime_error("set_grammar: invalid branch handle");
1560 }
1561
1562 // Free old grammar
1563 if (state->grammar != 0) {
1564 s.free_grammar(state->grammar);
1565 state->grammar = 0;
1566 }
1567
1568 // Create new grammar if provided
1569 if (grammar_str && grammar_str[0] != '\0') {
1570 state->grammar = s.create_grammar(model, grammar_str);
1571 }
1572
1573 LLOYAL_LOG_DEBUG("[branch::set_grammar] %s grammar on handle=%u",
1574 state->grammar != 0 ? "Set" : "Cleared", handle);
1575}
1576
1592 BranchHandle handle,
1593 const llama_model* model,
1594 const char* grammar_str,
1595 const std::vector<std::string>& trigger_patterns,
1596 const std::vector<llama_token>& trigger_tokens,
1597 BranchStore& s) {
1598 BranchState* state = s.get(handle);
1599 if (!state) {
1600 throw std::runtime_error("set_grammar_lazy: invalid branch handle");
1601 }
1602
1603 if (state->grammar != 0) {
1604 s.free_grammar(state->grammar);
1605 state->grammar = 0;
1606 }
1607
1608 if (grammar_str && grammar_str[0] != '\0') {
1609 state->grammar = s.create_grammar_lazy(
1610 model, grammar_str, trigger_patterns, trigger_tokens);
1611 }
1612
1613 LLOYAL_LOG_DEBUG("[branch::set_grammar_lazy] %s grammar on handle=%u",
1614 state->grammar != 0 ? "Set" : "Cleared", handle);
1615}
1616
1627inline void prune(BranchHandle handle, BranchStore& s) {
1628 BranchState* state = s.get(handle);
1629 if (!state) return;
1630 if (!state->children.empty())
1631 throw std::runtime_error("prune: RESTRICT — branch has children. Use pruneSubtree() for CASCADE.");
1632 s.release(handle);
1633}
1634
1645 std::vector<BranchHandle> stack{h}, post_order;
1646 while (!stack.empty()) {
1647 BranchHandle cur = stack.back(); stack.pop_back();
1648 post_order.push_back(cur);
1649 BranchState* st = s.get(cur);
1650 if (st) for (auto child : st->children) stack.push_back(child);
1651 }
1652 for (auto it = post_order.rbegin(); it != post_order.rend(); ++it)
1653 prune(*it, s);
1654}
1655
1676
1677
1678 BranchState* state = s.get(handle);
1679 if (!state) {
1680 throw std::runtime_error("force_snapshot_logits: invalid branch handle");
1681 }
1682
1683 // logits::get() throws if ctx is null or logits unavailable
1684 const float* raw_logits = logits::get(state->ctx, -1);
1685
1686 if (state->n_vocab <= 0) {
1687 throw std::runtime_error("force_snapshot_logits: invalid vocab size");
1688 }
1689
1690 std::memcpy(state->logits_snapshot.data(), raw_logits,
1691 state->n_vocab * sizeof(float));
1692 state->has_logits = true;
1693}
1694
1711inline void prefill(
1712 BranchHandle handle,
1713 const llama_token* tokens,
1714 size_t n_tokens,
1715 BranchStore& s) {
1716
1717
1718 BranchState* state = s.get(handle);
1719 if (!state) {
1720 throw std::runtime_error("prefill: invalid branch handle");
1721 }
1722
1723 // Pass raw pointer directly - no vector copy needed
1724 if (decode::many(state->ctx, tokens, static_cast<int32_t>(n_tokens),
1725 state->position, state->n_batch, state->seq_id) != 0) {
1726 throw std::runtime_error("prefill: llama_decode failed");
1727 }
1728
1729 state->position += static_cast<llama_pos>(n_tokens);
1730 s.add_cells_used(static_cast<uint32_t>(n_tokens));
1731
1732 // logits::get() throws if logits unavailable
1733 const float* raw_logits = logits::get(state->ctx, -1);
1734
1735 if (state->n_vocab <= 0) {
1736 throw std::runtime_error("prefill: invalid vocab size");
1737 }
1738
1739 std::memcpy(state->logits_snapshot.data(), raw_logits,
1740 state->n_vocab * sizeof(float));
1741 state->has_logits = true;
1742}
1743
1756inline void step(
1757 BranchHandle handle,
1758 llama_token token,
1759 BranchStore& s) {
1760
1761
1762 BranchState* state = s.get(handle);
1763 if (!state) {
1764 throw std::runtime_error("step: invalid branch handle");
1765 }
1766
1767 if (decode::one(state->ctx, token, state->position, state->seq_id, true) != 0) {
1768 throw std::runtime_error("step: llama_decode failed");
1769 }
1770 state->position += 1;
1771 s.add_cells_used(1);
1772
1773 // logits::get() throws if logits unavailable
1774 const float* raw_logits = logits::get(state->ctx, -1);
1775
1776 if (state->n_vocab <= 0) {
1777 throw std::runtime_error("step: invalid vocab size");
1778 }
1779
1780 std::memcpy(state->logits_snapshot.data(), raw_logits,
1781 state->n_vocab * sizeof(float));
1782 state->has_logits = true;
1783}
1784
1795inline const float* get_logits(BranchHandle handle, BranchStore& s) {
1796
1797
1798 const BranchState* state = s.get(handle);
1799 // Must check has_logits, not just empty() - buffer is pre-allocated in create()
1800 if (!state || !state->has_logits) {
1801 return nullptr;
1802 }
1803
1804 return state->logits_snapshot.data();
1805}
1806
1821inline llama_token sample(BranchHandle handle, BranchStore& s) {
1822
1823
1824 BranchState* state = s.get(handle);
1825 llama_sampler* chain = state ? s.get_sampler_chain(state->sampler_chain) : nullptr;
1826 if (!state || !chain) {
1827 return -1;
1828 }
1829
1830 // Must have logits captured before sampling
1831 if (!state->has_logits) {
1832 LLOYAL_LOG_DEBUG("[branch::sample] No logits captured - call prefill()/step() first");
1833 return -1;
1834 }
1835
1836 // Reuse pre-allocated candidates buffer (avoids O(n_vocab) allocs per sample)
1837 for (int i = 0; i < state->n_vocab; i++) {
1838 state->candidates_buffer[i] = llama_token_data{
1839 static_cast<llama_token>(i),
1840 state->logits_snapshot[i],
1841 0.0f};
1842 }
1843
1844 llama_token_data_array cur_p = {
1845 state->candidates_buffer.data(),
1846 static_cast<size_t>(state->n_vocab),
1847 -1,
1848 false};
1849
1850 // Apply grammar first if present (via anti-corruption layer)
1851 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
1852 if (gram) {
1853 grammar::apply(gram, &cur_p);
1854 }
1855
1856 // Apply logit bias if present — O(n_biases) via direct index
1857 // (candidates are in token-ID order; grammar::apply preserves order)
1858 if (!state->logit_bias.empty()) {
1859 for (const auto& bias : state->logit_bias) {
1860 if (bias.token >= 0 && bias.token < state->n_vocab) {
1861 cur_p.data[bias.token].logit += bias.bias;
1862 }
1863 }
1864 }
1865
1866 // Apply steer callback if present
1867 if (state->steer_fn) {
1868 try {
1869 state->steer_fn(cur_p);
1870 } catch (const std::exception& e) {
1871 LLOYAL_LOG_DEBUG("[branch::sample] Steer exception: %s", e.what());
1872 // Continue sampling without steer on exception
1873 }
1874 }
1875
1876 // Apply sampler chain (via anti-corruption layer)
1877 sampler::apply(chain, &cur_p);
1878
1879 if (cur_p.selected == -1) {
1880 return -1;
1881 }
1882
1883 llama_token token = cur_p.data[cur_p.selected].id;
1884
1885 // Capture filtered candidates for sampling metrics
1886 // After BOTH grammar and sampler chain - this is the actual sampling distribution
1887 state->last_token = token;
1888 state->last_candidates.clear();
1889 state->last_candidates.reserve(cur_p.size);
1890
1891 for (size_t i = 0; i < cur_p.size; i++) {
1892 state->last_candidates.push_back(cur_p.data[i]);
1893 }
1894
1895 return token;
1896}
1897
1913inline void accept_token(
1914 BranchHandle handle,
1915 llama_token token,
1916 BranchStore& s) {
1917
1918
1919 BranchState* state = s.get(handle);
1920 if (!state) return;
1921
1922 // Accept in grammar (via anti-corruption layer)
1923 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
1924 if (gram) {
1925 grammar::accept(gram, token);
1926 }
1927
1928 // Accept in sampler chain for penalty tracking (via anti-corruption layer)
1929 llama_sampler* chain = s.get_sampler_chain(state->sampler_chain);
1930 if (chain) {
1931 sampler::accept(chain, token);
1932 }
1933
1934 // Update model-level perplexity (from raw logits)
1935 // Guard on has_logits to avoid computing surprisal from zero-filled buffer
1936 if (state->metrics != 0 && state->has_logits) {
1937 float ms = metrics::model_surprisal(
1938 state->logits_snapshot.data(), state->n_vocab, token);
1939 s.add_model_surprisal(state->metrics, ms);
1940 }
1941
1942 // Update sampling-level perplexity (from filtered candidates)
1943 if (state->metrics != 0 && !state->last_candidates.empty() &&
1944 token == state->last_token) {
1945 // Extract filtered logits and IDs from candidates
1946 std::vector<float> candidate_logits;
1947 std::vector<int32_t> candidate_ids;
1948 candidate_logits.reserve(state->last_candidates.size());
1949 candidate_ids.reserve(state->last_candidates.size());
1950
1951 for (const auto& cand : state->last_candidates) {
1952 candidate_logits.push_back(cand.logit);
1953 candidate_ids.push_back(cand.id);
1954 }
1955
1956 // Compute sampling-level surprisal
1957 float ss = metrics::sampling_surprisal(
1958 candidate_logits.data(),
1959 candidate_ids.data(),
1960 static_cast<int>(candidate_logits.size()),
1961 token
1962 );
1963 s.add_sampling_surprisal(state->metrics, ss);
1964 }
1965}
1966
1981inline void apply_grammar(
1982 BranchHandle handle,
1983 float* logits,
1984 int n_vocab,
1985 BranchStore& s) {
1986
1987
1988 BranchState* state = s.get(handle);
1989 llama_sampler* gram = state ? s.get_grammar_sampler(state->grammar) : nullptr;
1990 if (!state || !gram) return;
1991
1992 // Use pre-allocated candidates buffer if size matches, otherwise allocate
1993 std::vector<llama_token_data>* candidates_ptr;
1994 std::vector<llama_token_data> temp_buffer;
1995
1996 if (n_vocab == state->n_vocab && !state->candidates_buffer.empty()) {
1997 candidates_ptr = &state->candidates_buffer;
1998 } else {
1999 temp_buffer.resize(n_vocab);
2000 candidates_ptr = &temp_buffer;
2001 }
2002
2003 auto& candidates = *candidates_ptr;
2004 for (int i = 0; i < n_vocab; i++) {
2005 candidates[i] = llama_token_data{
2006 static_cast<llama_token>(i), logits[i], 0.0f};
2007 }
2008
2009 llama_token_data_array cur_p = {
2010 candidates.data(),
2011 static_cast<size_t>(n_vocab),
2012 -1,
2013 false};
2014
2015 grammar::apply(gram, &cur_p);
2016
2017 // Copy masked logits back
2018 for (int i = 0; i < n_vocab; i++) {
2019 logits[i] = candidates[i].logit;
2020 }
2021}
2022
2037inline std::vector<std::pair<llama_token, float>> get_legal_priors(
2038 BranchHandle handle,
2039 BranchStore& s) {
2040
2041
2042 BranchState* state = s.get(handle);
2043 if (!state || !state->has_logits) {
2044 return {};
2045 }
2046
2047 // Reuse pre-allocated candidates buffer
2048 for (int i = 0; i < state->n_vocab; i++) {
2049 state->candidates_buffer[i] = llama_token_data{
2050 static_cast<llama_token>(i),
2051 state->logits_snapshot[i],
2052 0.0f};
2053 }
2054
2055 llama_token_data_array cur_p = {
2056 state->candidates_buffer.data(),
2057 static_cast<size_t>(state->n_vocab),
2058 -1,
2059 false};
2060
2061 // Apply grammar to mask illegal tokens
2062 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2063 if (gram) {
2064 grammar::apply(gram, &cur_p);
2065 }
2066
2067 // Collect legal candidates (logit is finite after grammar masking)
2068 // Grammar masking sets illegal tokens to -INFINITY
2069 std::vector<std::pair<llama_token, float>> legal_priors;
2070 float max_logit = -std::numeric_limits<float>::infinity();
2071
2072 for (size_t i = 0; i < cur_p.size; i++) {
2073 if (std::isfinite(cur_p.data[i].logit)) { // Not masked
2074 legal_priors.emplace_back(cur_p.data[i].id, cur_p.data[i].logit);
2075 if (cur_p.data[i].logit > max_logit) {
2076 max_logit = cur_p.data[i].logit;
2077 }
2078 }
2079 }
2080
2081 if (legal_priors.empty()) {
2082 return {};
2083 }
2084
2085 // Compute softmax over legal moves only (numerically stable)
2086 float sum_exp = 0.0f;
2087 for (auto& [token, logit] : legal_priors) {
2088 float exp_val = std::exp(logit - max_logit);
2089 logit = exp_val; // Temporarily store exp value
2090 sum_exp += exp_val;
2091 }
2092
2093 // Normalize to probabilities
2094 for (auto& [token, prob] : legal_priors) {
2095 prob /= sum_exp;
2096 }
2097
2098 return legal_priors;
2099}
2100
2117
2118
2119 BranchState* state = s.get(handle);
2120 if (!state || !state->has_logits) {
2121 return -std::numeric_limits<float>::infinity();
2122 }
2123
2124 // Reuse pre-allocated candidates buffer
2125 for (int i = 0; i < state->n_vocab; i++) {
2126 state->candidates_buffer[i] = llama_token_data{
2127 static_cast<llama_token>(i),
2128 state->logits_snapshot[i],
2129 0.0f};
2130 }
2131
2132 llama_token_data_array cur_p = {
2133 state->candidates_buffer.data(),
2134 static_cast<size_t>(state->n_vocab),
2135 -1,
2136 false};
2137
2138 // Apply grammar to mask illegal tokens
2139 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2140 if (gram) {
2141 grammar::apply(gram, &cur_p);
2142 }
2143
2144 // Numerically stable logsumexp over legal tokens (finite logits only)
2145 float max_logit = -std::numeric_limits<float>::infinity();
2146 for (size_t i = 0; i < cur_p.size; i++) {
2147 if (std::isfinite(cur_p.data[i].logit) && cur_p.data[i].logit > max_logit) {
2148 max_logit = cur_p.data[i].logit;
2149 }
2150 }
2151
2152 if (!std::isfinite(max_logit)) {
2153 return -std::numeric_limits<float>::infinity(); // No legal tokens
2154 }
2155
2156 float sum_exp = 0.0f;
2157 for (size_t i = 0; i < cur_p.size; i++) {
2158 if (std::isfinite(cur_p.data[i].logit)) {
2159 sum_exp += std::exp(cur_p.data[i].logit - max_logit);
2160 }
2161 }
2162
2163 return max_logit + std::log(sum_exp);
2164}
2165
2177inline bool is_token_legal(
2178 BranchHandle handle,
2179 llama_token token,
2180 BranchStore& s) {
2181
2182
2183 BranchState* state = s.get(handle);
2184 if (!state || token < 0 || token >= state->n_vocab) {
2185 return false;
2186 }
2187
2188 // No grammar = all tokens legal
2189 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2190 if (!gram) {
2191 return true;
2192 }
2193
2194 // Build 1-element candidate array (stack allocated, no heap)
2195 llama_token_data single_candidate = {
2196 token,
2197 state->has_logits ? state->logits_snapshot[token] : 0.0f,
2198 0.0f
2199 };
2200
2201 llama_token_data_array cur_p = {
2202 &single_candidate,
2203 1,
2204 -1,
2205 false
2206 };
2207
2208 // Apply grammar - will set logit to -INFINITY if illegal
2209 grammar::apply(gram, &cur_p);
2210
2211 return std::isfinite(single_candidate.logit);
2212}
2213
2229 BranchHandle handle,
2230 llama_token token,
2231 float logsumexp,
2232 BranchStore& s) {
2233
2234
2235 BranchState* state = s.get(handle);
2236 if (!state || !state->has_logits || token < 0 || token >= state->n_vocab) {
2237 return 0.0f;
2238 }
2239
2240 float logit = state->logits_snapshot[token];
2241 return std::exp(logit - logsumexp);
2242}
2243
2259inline float get_token_prior(
2260 BranchHandle handle,
2261 llama_token token,
2262 float logsumexp,
2263 BranchStore& s) {
2264 if (!is_token_legal(handle, token, s)) {
2265 return 0.0f;
2266 }
2267 return get_token_prior_assume_legal(handle, token, logsumexp, s);
2268}
2269
2270// ===== STATE ACCESSORS =====
2271
2272
2279inline llama_pos get_position(BranchHandle handle, BranchStore& s) {
2280
2281 const BranchState* state = s.get(handle);
2282 return state ? state->position : -1;
2283}
2284
2291inline llama_pos get_fork_head(BranchHandle handle, BranchStore& s) {
2292 const BranchState* state = s.get(handle);
2293 return state ? state->fork_head : 0;
2294}
2295
2307inline float get_perplexity(BranchHandle handle, BranchStore& s) {
2308
2309 const BranchState* state = s.get(handle);
2310 if (!state || state->metrics == 0) {
2311 return std::numeric_limits<float>::infinity();
2312 }
2313 return s.get_model_ppl(state->metrics);
2314}
2315
2328
2329 const BranchState* state = s.get(handle);
2330 if (!state || state->metrics == 0) {
2331 return std::numeric_limits<float>::infinity();
2332 }
2333 return s.get_sampling_ppl(state->metrics);
2334}
2335
2347
2348 const BranchState* state = s.get(handle);
2349
2350 if (!state || state->last_candidates.empty() || state->last_token < 0) {
2351 return 0.0f;
2352 }
2353
2354 // Extract candidates
2355 std::vector<float> candidate_logits;
2356 std::vector<int32_t> candidate_ids;
2357 candidate_logits.reserve(state->last_candidates.size());
2358 candidate_ids.reserve(state->last_candidates.size());
2359
2360 for (const auto& cand : state->last_candidates) {
2361 candidate_logits.push_back(cand.logit);
2362 candidate_ids.push_back(cand.id);
2363 }
2364
2365 // Compute surprisal from filtered distribution
2366 float surprisal = metrics::sampling_surprisal(
2367 candidate_logits.data(),
2368 candidate_ids.data(),
2369 static_cast<int>(candidate_logits.size()),
2370 state->last_token
2371 );
2372
2373 // Convert to probability: P = exp(-surprisal)
2374 return std::exp(-surprisal);
2375}
2376
2383inline int get_n_vocab(BranchHandle handle, BranchStore& s) {
2384
2385 const BranchState* state = s.get(handle);
2386 return state ? state->n_vocab : 0;
2387}
2388
2389
2390// ===== RAII WRAPPER =====
2391
2409class Branch {
2410public:
2411 Branch() : store_(nullptr), handle_(INVALID_HANDLE) {}
2412
2414 : store_(store), handle_(handle) {}
2415
2418 if (handle_ != INVALID_HANDLE && store_) {
2419 branch::pruneSubtree(handle_, *store_);
2420 }
2421 }
2422
2423 Branch(Branch&& other) noexcept
2424 : store_(other.store_), handle_(other.handle_) {
2425 other.handle_ = INVALID_HANDLE;
2426 }
2427
2428 Branch& operator=(Branch&& other) noexcept {
2429 if (this != &other) {
2430 if (handle_ != INVALID_HANDLE && store_) {
2431 branch::pruneSubtree(handle_, *store_);
2432 }
2433 store_ = other.store_;
2434 handle_ = other.handle_;
2435 other.handle_ = INVALID_HANDLE;
2436 }
2437 return *this;
2438 }
2439
2440 Branch(const Branch&) = delete;
2441 Branch& operator=(const Branch&) = delete;
2442
2444 template <SamplingParamsLike P>
2446 llama_context* ctx,
2447 const llama_model* model,
2448 BranchStore& store,
2449 llama_pos start_pos,
2450 const P& params,
2451 int n_batch = DEFAULT_N_BATCH,
2452 const char* grammar_str = nullptr,
2453 boundaries::BoundaryTracker* boundary_tracker = nullptr) {
2454 BranchHandle h = branch::create(ctx, model, store, start_pos, params, n_batch, grammar_str, boundary_tracker);
2455 return Branch(&store, h);
2456 }
2457
2460 BranchHandle h = branch::fork(handle_, *store_);
2461 return Branch(store_, h);
2462 }
2463
2465 void prune() {
2466 branch::prune(handle_, *store_);
2467 handle_ = INVALID_HANDLE;
2468 }
2469
2472 branch::pruneSubtree(handle_, *store_);
2473 handle_ = INVALID_HANDLE;
2474 }
2475
2479 branch::force_snapshot_logits(handle_, *store_);
2480 }
2481
2485 void prefill(const llama_token* tokens, size_t n) {
2486 branch::prefill(handle_, tokens, n, *store_);
2487 }
2488
2491 void step(llama_token token) {
2492 branch::step(handle_, token, *store_);
2493 }
2494
2497 const float* logits() const {
2498 return branch::get_logits(handle_, *store_);
2499 }
2500
2504 llama_token sample() {
2505 return branch::sample(handle_, *store_);
2506 }
2507
2510 void accept(llama_token token) {
2511 branch::accept_token(handle_, token, *store_);
2512 }
2513
2515 bool is_eog(llama_token token) const {
2516 const BranchState* st = store_ ? store_->get(handle_) : nullptr;
2517 return st && st->model ? tokenizer::is_eog(st->model, token) : false;
2518 }
2519
2522 template <SamplingParamsLike P>
2523 void setSamplerParams(const P& params) {
2524 branch::set_sampler_params(handle_, params, *store_);
2525 }
2526
2529 void setGrammar(const char* grammar_str) {
2530 const BranchState* st = store_ ? store_->get(handle_) : nullptr;
2531 branch::set_grammar(handle_, st ? st->model : nullptr, grammar_str, *store_);
2532 }
2533
2534 // ===== ACCESSORS =====
2535
2537 llama_pos position() const { return branch::get_position(handle_, *store_); }
2539 llama_pos forkHead() const { return branch::get_fork_head(handle_, *store_); }
2541 float perplexity() const { return branch::get_perplexity(handle_, *store_); }
2543 int n_vocab() const { return branch::get_n_vocab(handle_, *store_); }
2545 bool valid() const { return handle_ != INVALID_HANDLE; }
2547 BranchHandle handle() const { return handle_; }
2548
2549 // ===== TOPOLOGY =====
2550
2552 BranchHandle parentHandle() const { return store_ ? store_->parent(handle_) : INVALID_HANDLE; }
2554 const std::vector<BranchHandle>& childHandles() const {
2555 static const std::vector<BranchHandle> empty;
2556 return store_ ? store_->children(handle_) : empty;
2557 }
2559 bool isLeaf() const { return store_ ? store_->isLeaf(handle_) : true; }
2561 bool isActive() const { return store_ ? store_->isActive(handle_) : false; }
2562
2563private:
2564 BranchStore* store_;
2565 BranchHandle handle_;
2566};
2567
2568} // namespace lloyal::branch
Stub BoundaryTracker - does nothing.
virtual std::unique_ptr< BoundaryTracker > clone() const
Handle table and batched decode orchestrator for branch management.
Definition branch.hpp:392
GrammarHandle create_grammar_lazy(const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const char *root="root")
Create a lazy grammar (unconstrained until trigger fires)
Definition branch.hpp:757
float get_sampling_ppl(MetricsHandle h) const
Get sampling-level perplexity from a metrics tracker.
Definition branch.hpp:888
bool isActive(BranchHandle h) const
Test whether a branch holds a KV lease.
Definition branch.hpp:626
void free_grammar(GrammarHandle h)
Free a grammar.
Definition branch.hpp:792
GrammarHandle create_grammar(const llama_model *model, const char *grammar_str, const char *root="root")
Create a grammar sampler and register it.
Definition branch.hpp:738
MetricsHandle clone_metrics(MetricsHandle h)
Clone a metrics tracker (for fork)
Definition branch.hpp:824
size_t available() const
Number of vacant seq_ids available for acquisition.
Definition branch.hpp:561
void release(BranchHandle handle)
Release a branch slot + evict its KV lease.
Definition branch.hpp:463
void drain()
Explicit teardown — evict all leases while context is alive.
Definition branch.hpp:512
float get_model_ppl(MetricsHandle h) const
Get model-level perplexity from a metrics tracker.
Definition branch.hpp:874
bool isLeaf(BranchHandle h) const
Test whether a branch is a leaf (no children)
Definition branch.hpp:616
void decode_each(std::span< const DecodeEachItem > items)
Decode one token per branch in a single GPU dispatch.
Definition branch.hpp:921
bool sampler_has_dist(SamplerChainHandle h) const
Check if a sampler chain ends with dist (stochastic) or greedy.
Definition branch.hpp:723
MetricsHandle create_metrics()
Create a metrics tracker and register it.
Definition branch.hpp:813
~BranchStore()
Destructor — frees CPU resources.
Definition branch.hpp:417
SamplerChainHandle clone_sampler(SamplerChainHandle h)
Clone a sampler chain (for fork)
Definition branch.hpp:687
llama_pos fork_head(BranchHandle h) const
Get a branch's fork head (parent position at fork time)
Definition branch.hpp:595
Allocation allocate()
Allocate a branch slot + KV lease atomically.
Definition branch.hpp:439
void add_cells_used(uint32_t n)
Increment cells_used counter (for standalone prefill/step outside BranchStore methods)
Definition branch.hpp:578
const BranchState * get(BranchHandle handle) const
Look up branch state by handle.
Definition branch.hpp:660
void free_sampler(SamplerChainHandle h)
Free a sampler chain.
Definition branch.hpp:703
void retainOnly(BranchHandle winner)
Keep only the winner — nuclear KV + CPU cleanup.
Definition branch.hpp:534
KvPressure kv_pressure() const
KV cache pressure snapshot — O(1), no tree walking.
Definition branch.hpp:570
SamplerChainHandle create_sampler(const P &params)
Create a sampler chain and register it.
Definition branch.hpp:672
llama_sampler * get_grammar_sampler(GrammarHandle h) const
Dereference a grammar handle (non-owning)
Definition branch.hpp:801
void add_sampling_surprisal(MetricsHandle h, float surprisal)
Add sampling-level surprisal to a metrics tracker.
Definition branch.hpp:860
void init_tenancy(llama_context *ctx)
Initialize KV tenancy after context creation.
Definition branch.hpp:499
GrammarHandle clone_grammar(GrammarHandle h)
Clone a grammar (for fork)
Definition branch.hpp:777
void decode_scatter(std::span< const DecodeScatterItem > items)
Decode variable token counts per branch with auto-chunking.
Definition branch.hpp:993
BranchStore(size_t initial_capacity=16)
Construct a branch store with initial slot capacity.
Definition branch.hpp:398
void free_metrics(MetricsHandle h)
Free a metrics tracker.
Definition branch.hpp:837
const std::vector< BranchHandle > & children(BranchHandle h) const
Get a branch's child handles.
Definition branch.hpp:605
BranchState * get(BranchHandle handle)
Look up branch state by handle.
Definition branch.hpp:640
BranchHandle parent(BranchHandle h) const
Get a branch's parent handle.
Definition branch.hpp:585
void add_model_surprisal(MetricsHandle h, float surprisal)
Add model-level surprisal to a metrics tracker.
Definition branch.hpp:846
llama_sampler * get_sampler_chain(SamplerChainHandle h) const
Dereference a sampler chain handle (non-owning)
Definition branch.hpp:712
~Branch()
Destructor — CASCADE prunes entire subtree.
Definition branch.hpp:2417
void accept(llama_token token)
Accept a token — advance grammar, penalty window, and metrics.
Definition branch.hpp:2510
Branch & operator=(const Branch &)=delete
int n_vocab() const
Vocabulary size.
Definition branch.hpp:2543
void setGrammar(const char *grammar_str)
Replace grammar constraint (nullptr/empty to remove)
Definition branch.hpp:2529
const std::vector< BranchHandle > & childHandles() const
Child branch handles (empty if leaf)
Definition branch.hpp:2554
void setSamplerParams(const P &params)
Replace sampler chain with new parameters (memoized)
Definition branch.hpp:2523
BranchHandle handle() const
Underlying opaque handle (for interop with free functions)
Definition branch.hpp:2547
const float * logits() const
Get the branch's captured logits snapshot.
Definition branch.hpp:2497
bool is_eog(llama_token token) const
Check if a token is end-of-generation for this branch's model.
Definition branch.hpp:2515
BranchHandle parentHandle() const
Parent branch handle, or INVALID_HANDLE if root.
Definition branch.hpp:2552
Branch(BranchStore *store, BranchHandle handle)
Definition branch.hpp:2413
llama_pos position() const
Current decode position (token count)
Definition branch.hpp:2537
llama_token sample()
Sample a token from captured logits.
Definition branch.hpp:2504
void step(llama_token token)
Decode one token and capture logits (generation step)
Definition branch.hpp:2491
void pruneSubtree()
CASCADE prune — removes entire subtree.
Definition branch.hpp:2471
float perplexity() const
Model-level perplexity (from raw logits, pre-filter)
Definition branch.hpp:2541
Branch(const Branch &)=delete
bool isActive() const
True if this branch holds a KV lease.
Definition branch.hpp:2561
llama_pos forkHead() const
Parent's position at fork time (0 for root branches)
Definition branch.hpp:2539
void force_snapshot_logits()
Force-copy shared logits buffer into this branch's snapshot.
Definition branch.hpp:2478
bool isLeaf() const
True if this branch has no children.
Definition branch.hpp:2559
void prefill(const llama_token *tokens, size_t n)
Decode multiple tokens and capture logits atomically (prompt injection)
Definition branch.hpp:2485
bool valid() const
True if this Branch holds a valid handle.
Definition branch.hpp:2545
void prune()
RESTRICT prune (throws if children exist)
Definition branch.hpp:2465
Branch fork()
Fork: allocates slot + lease, records topology edge.
Definition branch.hpp:2459
Branch & operator=(Branch &&other) noexcept
Definition branch.hpp:2428
static Branch create(llama_context *ctx, const llama_model *model, BranchStore &store, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Factory: allocates slot + lease from store.
Definition branch.hpp:2445
Branch(Branch &&other) noexcept
Definition branch.hpp:2423
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:47
Batch Decoding Operations.
Grammar-Constrained Sampling.
constexpr llama_seq_id NO_LEASE
Sentinel value indicating a branch has no KV residency.
Definition kv.hpp:206
KV Cache Physics.
Zero-copy logits access with clear lifetime semantics.
Distribution Metrics for Test-Time Alignment.
void prefill(BranchHandle handle, const llama_token *tokens, size_t n_tokens, BranchStore &s)
Decode multiple tokens and capture logits atomically (prompt prefill)
Definition branch.hpp:1711
void set_logit_bias(BranchHandle handle, const llama_logit_bias *biases, size_t n_biases, BranchStore &s)
Definition branch.hpp:1395
float get_perplexity(BranchHandle handle, BranchStore &s)
Get model-level perplexity (from raw logits)
Definition branch.hpp:2307
void apply_grammar(BranchHandle handle, float *logits, int n_vocab, BranchStore &s)
Apply grammar constraints to an external logits buffer.
Definition branch.hpp:1981
int32_t GrammarHandle
Handle to a grammar sampler in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:154
uint32_t BranchHandle
Opaque handle to a branch slot.
Definition branch.hpp:94
void step(BranchHandle handle, llama_token token, BranchStore &s)
Decode a single token and capture logits (generation step)
Definition branch.hpp:1756
constexpr int DEFAULT_N_BATCH
Default batch size for decode operations.
Definition branch.hpp:98
void set_grammar_lazy(BranchHandle handle, const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, BranchStore &s)
Set lazy grammar on a branch (unconstrained until trigger fires)
Definition branch.hpp:1591
void set_steer(BranchHandle handle, std::function< void(llama_token_data_array &)> steer_fn, BranchStore &s)
Definition branch.hpp:1462
constexpr uint32_t GEN_SHIFT
Bit shift for generation field.
Definition branch.hpp:99
void accept_token(BranchHandle handle, llama_token token, BranchStore &s)
Accept a sampled token, advancing grammar and sampler state.
Definition branch.hpp:1913
void force_snapshot_logits(BranchHandle handle, BranchStore &s)
Force-copy the shared llama.cpp logits buffer into this branch's private snapshot.
Definition branch.hpp:1675
llama_pos get_fork_head(BranchHandle handle, BranchStore &s)
Get the branch's fork head (parent position at fork time)
Definition branch.hpp:2291
float get_sampling_perplexity(BranchHandle handle, BranchStore &s)
Get sampling-level perplexity (from filtered distribution)
Definition branch.hpp:2327
float get_legal_logsumexp(BranchHandle handle, BranchStore &s)
Compute log-sum-exp over grammar-legal logits.
Definition branch.hpp:2116
void clear_logit_bias(BranchHandle handle, BranchStore &s)
Clear all logit biases from a branch.
Definition branch.hpp:1421
uint16_t handle_generation(BranchHandle h)
Extract generation counter from a branch handle.
Definition branch.hpp:134
const float * get_logits(BranchHandle handle, BranchStore &s)
Get the branch's captured logits snapshot.
Definition branch.hpp:1795
void set_sampler_params(BranchHandle handle, const P &params, BranchStore &s)
Replace a branch's sampler chain with new parameters.
Definition branch.hpp:1516
void prune(BranchHandle handle, BranchStore &s)
Prune a leaf branch (RESTRICT — throws if children exist)
Definition branch.hpp:1627
BranchHandle create(llama_context *ctx, const llama_model *model, BranchStore &s, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Create a new branch with sampler chain, optional grammar, and metrics.
Definition branch.hpp:1225
CachedSamplingParams snapshot_params(const P &p)
Snapshot sampling params for memoization comparison.
Definition branch.hpp:237
constexpr uint32_t INDEX_MASK
Mask for slot index field.
Definition branch.hpp:100
BranchHandle fork(BranchHandle source, BranchStore &s)
Fork a branch into a new independent sequence.
Definition branch.hpp:1300
int get_n_vocab(BranchHandle handle, BranchStore &s)
Get the branch's vocabulary size.
Definition branch.hpp:2383
llama_token sample(BranchHandle handle, BranchStore &s)
Sample a token from the branch's captured logits.
Definition branch.hpp:1821
void pruneSubtree(BranchHandle h, BranchStore &s)
Prune a branch and all descendants (CASCADE — iterative post-order)
Definition branch.hpp:1644
constexpr BranchHandle INVALID_HANDLE
Null handle sentinel.
Definition branch.hpp:96
float get_token_prior(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token, checking grammar legality first.
Definition branch.hpp:2259
constexpr llama_seq_id NO_LEASE
Branch has no KV residency.
Definition branch.hpp:97
int32_t MetricsHandle
Handle to a metrics tracker in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:157
llama_pos get_position(BranchHandle handle, BranchStore &s)
Get the branch's current decode position.
Definition branch.hpp:2279
void set_grammar(BranchHandle handle, const llama_model *model, const char *grammar_str, BranchStore &s)
Replace a branch's grammar constraint.
Definition branch.hpp:1552
void clear_steer(BranchHandle handle, BranchStore &s)
Clear the steer callback from a branch.
Definition branch.hpp:1485
float get_token_prior_assume_legal(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token known to be grammar-legal.
Definition branch.hpp:2228
BranchHandle make_handle(uint16_t index, uint16_t generation)
Construct a branch handle from index and generation.
Definition branch.hpp:144
uint16_t handle_index(BranchHandle h)
Extract slot index from a branch handle.
Definition branch.hpp:125
int32_t SamplerChainHandle
Handle to a sampler chain in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:151
std::vector< std::pair< llama_token, float > > get_legal_priors(BranchHandle handle, BranchStore &s)
Get grammar-legal tokens with renormalized probabilities.
Definition branch.hpp:2037
bool is_token_legal(BranchHandle handle, llama_token token, BranchStore &s)
Check if a token is legal under grammar constraints.
Definition branch.hpp:2177
float get_last_sampling_prior(BranchHandle handle, BranchStore &s)
Get the last sampled token's prior from the filtered distribution.
Definition branch.hpp:2346
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
Definition decode.hpp:480
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
Definition decode.hpp:238
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
Definition decode.hpp:124
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
Definition decode.hpp:395
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Definition decode.hpp:339
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
Definition sampler.hpp:50
void free_sampler(llama_sampler *smpl)
Free a grammar sampler.
Definition grammar.hpp:234
llama_sampler * clone_sampler(llama_sampler *smpl)
Clone a grammar sampler (for fork/branching).
Definition grammar.hpp:212
llama_sampler * init_sampler(const llama_model *model, const std::string &grammar_str, const std::string &root_rule="root")
Initialize a grammar sampler from GBNF grammar string.
Definition grammar.hpp:106
llama_sampler * init_lazy_sampler(const llama_model *model, const std::string &grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const std::string &root_rule="root")
Initialize a lazy grammar sampler from GBNF grammar string.
Definition grammar.hpp:157
void accept(llama_sampler *smpl, llama_token token)
Accept a token into grammar state.
Definition grammar.hpp:262
void apply(llama_sampler *smpl, llama_token_data_array *cur_p)
Apply grammar constraint to candidates.
Definition grammar.hpp:248
void evict_all(State &s)
Evict every leased seq_id.
Definition kv.hpp:350
llama_seq_id acquire(State &s)
Acquire a seq_id from the vacant pool.
Definition kv.hpp:255
size_t available(const State &s)
Number of vacant seq_ids available for acquisition.
Definition kv.hpp:364
void evict(State &s, llama_seq_id seq)
Evict a seq_id — strip all KV tags then release.
Definition kv.hpp:296
void retain(State &s, llama_seq_id keep)
Nuclear retain — keep one seq, rebuild vacancy from scratch.
Definition kv.hpp:319
State init(llama_context *ctx, llama_seq_id n_seq_max)
Initialize tenancy with all seq_ids vacant.
Definition kv.hpp:233
void release(State &s, llama_seq_id seq)
Release a seq_id back to vacant — bookkeeping only, no KV calls.
Definition kv.hpp:274
void seq_cp(llama_context *ctx, llama_seq_id src, llama_seq_id dst, llama_pos p0=0, llama_pos p1=-1)
Copy KV cache from one sequence to another.
Definition kv.hpp:137
llama_pos pos_max(llama_context *ctx, llama_seq_id seq)
Get maximum position in KV cache sequence.
Definition kv.hpp:110
float * get(llama_context *ctx, int32_t index=-1)
Get logits for a token position from the context (index -1 = last decoded token).
Definition logits.hpp:78
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
Definition metrics.hpp:226
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Compute model-level surprisal for picked token.
Definition metrics.hpp:131
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
Definition sampler.hpp:581
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
Definition sampler.hpp:595
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
Definition sampler.hpp:526
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
Definition sampler.hpp:465
void free_chain(llama_sampler *chain)
Free a sampler chain.
Definition sampler.hpp:567
bool is_eog(const llama_vocab *vocab, llama_token token)
Check if token is end-of-generation marker.
Token Sampling Operations.
Consolidated mutable state for a single branch.
Definition branch.hpp:275
bool in_use
True when slot is allocated to an active branch.
Definition branch.hpp:315
const llama_model * model
Llama model (not owned, must outlive branch)
Definition branch.hpp:277
bool has_logits
True only after force_snapshot_logits(), prefill(), or step()
Definition branch.hpp:299
std::function< void(llama_token_data_array &)> steer_fn
Dynamic logit callback, NOT cloned on fork.
Definition branch.hpp:291
int n_batch
Batch size for decode operations.
Definition branch.hpp:311
llama_seq_id seq_id
KV cache sequence identifier (NO_LEASE when inactive)
Definition branch.hpp:279
llama_pos position
Current decode position in the sequence.
Definition branch.hpp:280
std::vector< llama_logit_bias > logit_bias
Static token biases, cloned on fork.
Definition branch.hpp:290
std::vector< llama_token_data > last_candidates
Filtered candidates from last sample()
Definition branch.hpp:296
std::vector< float > logits_snapshot
Captured logit distribution (n_vocab floats)
Definition branch.hpp:298
std::vector< llama_token_data > candidates_buffer
Reusable scratch buffer for sampling (avoids O(n_vocab) allocs per sample call).
Definition branch.hpp:309
boundaries::BoundaryTracker * boundary_tracker
Token boundary detector (owned, optional)
Definition branch.hpp:288
SamplerChainHandle sampler_chain
Handle into BranchStore's sampler registry.
Definition branch.hpp:283
llama_context * ctx
Llama context (not owned, must outlive branch)
Definition branch.hpp:276
BranchHandle parent
Parent branch (INVALID_HANDLE if root)
Definition branch.hpp:318
int n_vocab
Vocabulary size (cached for buffer pre-allocation)
Definition branch.hpp:312
GrammarHandle grammar
Handle into BranchStore's grammar registry.
Definition branch.hpp:284
std::vector< BranchHandle > children
Child branches forked from this one.
Definition branch.hpp:319
llama_token last_token
Last token returned by sample()
Definition branch.hpp:295
MetricsHandle metrics
Handle into BranchStore's metrics registry.
Definition branch.hpp:293
llama_pos fork_head
Parent's position at fork time (0 for root branches)
Definition branch.hpp:281
CachedSamplingParams cached_params
Params used to create current chain (for memoization)
Definition branch.hpp:286
uint16_t generation
Slot generation counter (for ABA prevention)
Definition branch.hpp:314
Result of allocate(): a slot handle + its leased seq_id.
Definition branch.hpp:429
Concrete sampling params snapshot for memoization.
Definition branch.hpp:215
bool operator==(const CachedSamplingParams &) const =default
Item for decode_each: one token per branch.
Definition branch.hpp:330
Item for decode_scatter: variable tokens per branch.
Definition branch.hpp:348
std::span< const llama_token > tokens
Definition branch.hpp:350
RAII entry for a grammar sampler in the registry.
Definition branch.hpp:190
GrammarEntry(GrammarEntry &&o) noexcept
Definition branch.hpp:196
GrammarEntry & operator=(const GrammarEntry &)=delete
GrammarEntry & operator=(GrammarEntry &&o) noexcept
Definition branch.hpp:197
GrammarEntry(const GrammarEntry &)=delete
llama_sampler * sampler
Definition branch.hpp:191
Snapshot of KV cache pressure from BranchStore.
Definition branch.hpp:114
uint32_t remaining
n_ctx - cells_used (clamped to 0)
Definition branch.hpp:117
uint32_t n_ctx
Total KV capacity.
Definition branch.hpp:115
uint32_t cells_used
Cells allocated since last reset.
Definition branch.hpp:116
RAII entry for a sampler chain in the registry.
Definition branch.hpp:165
SamplerChainEntry(const SamplerChainEntry &)=delete
SamplerChainEntry & operator=(SamplerChainEntry &&o) noexcept
Definition branch.hpp:174
SamplerChainEntry(SamplerChainEntry &&o) noexcept
Definition branch.hpp:172
bool has_dist
True if chain ends with dist (temp > 0), false if greedy.
Definition branch.hpp:167
SamplerChainEntry & operator=(const SamplerChainEntry &)=delete
llama_context * ctx
Context for KV operations (nullptr after drain)
Definition kv.hpp:217
Unified model + sampling perplexity tracker.
Definition metrics.hpp:99