liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
branch.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6
#include "boundaries.hpp"
#include "common.hpp"
#include "decode.hpp"
#include "grammar.hpp"
#include "kv.hpp"
#include "logits.hpp"
#include "metrics.hpp"
#include "sampler.hpp"

#include <llama/llama.h>
#include <algorithm> // std::remove
#include <cassert> // assert
#include <cmath> // std::exp, std::log, std::isinf, std::isfinite
#include <cstdint>
#include <cstring> // std::memcpy
#include <ctime> // std::time
#include <deque> // std::deque (pointer stability for BranchStore)
#include <functional> // std::function
#include <limits> // std::numeric_limits
#include <mutex>
#include <span> // std::span (C++20)
#include <stdexcept> // std::runtime_error
#include <string> // std::to_string
#include <unordered_map> // std::unordered_map (handle registries)
#include <utility> // std::pair, std::exchange
#include <vector>
82
83namespace lloyal::branch {
84
85// ===== HANDLE TYPE =====
86
/// Opaque 32-bit branch handle: high 16 bits = generation counter (ABA
/// guard, see reset_slot), low 16 bits = slot index. Handle 0 is reserved
/// as "invalid".
using BranchHandle = uint32_t;

// NOTE(review): the INVALID_HANDLE definition was elided in this extraction
// (it is referenced throughout this file; presumably `constexpr BranchHandle
// INVALID_HANDLE = 0;` — confirm in the header).

/// Sentinel: branch currently holds no KV sequence lease.
constexpr llama_seq_id NO_LEASE = kv::NO_LEASE;
/// Default decode batch size for a branch when none is given.
constexpr int DEFAULT_N_BATCH = 512;
/// BranchHandle bit layout (see handle_index / handle_generation).
constexpr uint32_t GEN_SHIFT = 16;
constexpr uint32_t INDEX_MASK = 0xFFFF;

// NOTE(review): the declaration line of the following aggregate was elided
// in this extraction — presumably the context-usage report returned by
// BranchStore::contextUsage() as {n_ctx, cells_used, remaining}; confirm
// its name in the header.
    uint32_t n_ctx;      // total KV context size of the bound llama_context
    uint32_t cells_used; // cells currently attributed to live branches
    uint32_t remaining;  // n_ctx - cells_used, clamped at 0
};
119
125inline uint16_t handle_index(BranchHandle h) {
126 return static_cast<uint16_t>(h & INDEX_MASK);
127}
128
134inline uint16_t handle_generation(BranchHandle h) {
135 return static_cast<uint16_t>(h >> GEN_SHIFT);
136}
137
144inline BranchHandle make_handle(uint16_t index, uint16_t generation) {
145 return (static_cast<uint32_t>(generation) << GEN_SHIFT) | index;
146}
147
148// ===== REGISTRY HANDLE TYPES =====
149
151using SamplerChainHandle = int32_t;
152
154using GrammarHandle = int32_t;
155
157using MetricsHandle = int32_t;
158
    // Registry entry owning one llama sampler chain. Move-only: `chain` is
    // an owned raw pointer transferred on move.
    // NOTE(review): this extraction elided the struct's declaration line and
    // its destructor / deleted-copy lines — presumably the destructor frees
    // `chain`; confirm in the header.
    llama_sampler* chain = nullptr;
    bool has_dist = false; // set when temperature > 0 (see create_sampler)

    SamplerChainEntry() = default;

    // NOTE(review): move-constructor signature elided; the visible
    // init-list transfers ownership and nulls the source.
        : chain(o.chain), has_dist(o.has_dist) { o.chain = nullptr; }
    // NOTE(review): move-assignment signature elided, as is the statement
    // between the self-assignment guard and the member copies — presumably
    // it releases the currently held chain before overwriting; confirm.
        if (this != &o) {
            chain = o.chain; has_dist = o.has_dist; o.chain = nullptr;
        }
        return *this;
    }
};

    // Registry entry owning one grammar sampler. Move-only, same ownership
    // pattern as SamplerChainEntry.
    // NOTE(review): declaration line, destructor and deleted copy-assign
    // were elided in this extraction.
    llama_sampler* sampler = nullptr;

    GrammarEntry() = default;

    GrammarEntry(GrammarEntry&& o) noexcept : sampler(o.sampler) { o.sampler = nullptr; }
    // NOTE(review): move-assignment signature elided; presumably the held
    // sampler is released before overwrite — confirm in the header.
        if (this != &o) {
            sampler = o.sampler; o.sampler = nullptr;
        }
        return *this;
    }
    GrammarEntry(const GrammarEntry&) = delete;
};

    // Plain-value snapshot of sampling parameters, used to memoize sampler
    // chain rebuilds in set_sampler_params().
    // NOTE(review): the struct declaration line was elided in this
    // extraction — the type is CachedSamplingParams, per operator== below
    // and uses in reset_slot()/snapshot_params().
    float temperature = 0.8f;
    int32_t top_k = 40;
    float top_p = 0.95f;
    float typical_p = 1.0f;
    float min_p = 0.05f;
    float penalty_repeat = 1.0f;
    float penalty_freq = 0.0f;
    float penalty_present = 0.0f;
    int32_t penalty_last_n = 64;
    uint32_t seed = 0;
    // C++20 defaulted memberwise comparison — drives the memoization check.
    bool operator==(const CachedSamplingParams&) const = default;
};

// Snapshot a SamplingParamsLike object into plain values, substituting the
// library defaults for unset fields (via detail::as_value).
// NOTE(review): the function signature and the `return {` line were elided
// in this extraction — presumably
// `inline CachedSamplingParams snapshot_params(const P& p) {`.
template <SamplingParamsLike P>
    using ::lloyal::detail::as_value;
        as_value(p.temperature, 0.8f),
        as_value(p.top_k, static_cast<int32_t>(40)),
        as_value(p.top_p, 0.95f),
        as_value(p.typical_p, 1.0f),
        as_value(p.min_p, 0.05f),
        as_value(p.penalty_repeat, 1.0f),
        as_value(p.penalty_freq, 0.0f),
        as_value(p.penalty_present, 0.0f),
        as_value(p.penalty_last_n, static_cast<int32_t>(64)),
        as_value(p.seed, static_cast<uint32_t>(0)),
    };
}
252
253// ===== BRANCH STATE =====
254
    // Per-branch state stored in one BranchStore slot.
    // NOTE(review): this extraction elided the struct declaration line and
    // several member declarations — the registry handles (sampler_chain,
    // grammar, metrics), boundary_tracker, n_batch, parent and
    // cached_params are referenced elsewhere in this file but not visible
    // here; confirm their declarations in the header.
    llama_context* ctx = nullptr;       // borrowed, not owned
    const llama_model* model = nullptr; // borrowed, not owned

    llama_seq_id seq_id = NO_LEASE; // KV sequence lease; NO_LEASE when inactive
    llama_pos position = 0;         // next decode position in the sequence
    llama_pos fork_head = 0;        // position at which this branch forked off

    std::vector<llama_logit_bias> logit_bias;              // applied at sampling time
    std::function<void(llama_token_data_array&)> steer_fn; // optional candidate hook

    llama_token last_token = -1;
    std::vector<llama_token_data> last_candidates;

    std::vector<float> logits_snapshot; // last captured logits, n_vocab floats
    bool has_logits = false;

    std::vector<llama_token_data> candidates_buffer; // scratch, sized to n_vocab

    int n_vocab = 0;

    uint16_t generation = 0; // ABA guard; bumped by reset_slot
    bool in_use = false;

    // Topology — maintained by fork/prune/pruneSubtree
    std::vector<BranchHandle> children;
};
321
322// ===== BATCHED DECODE ITEM TYPES =====
323
334
350 std::span<const llama_token> tokens;
351};
352
353// ===== BRANCH STORE (HANDLE TABLE) =====
354
393public:
398 explicit BranchStore(size_t initial_capacity = 16) {
399 // Ensure minimum capacity of 2 (slot 0 reserved + at least 1 usable)
400 if (initial_capacity < 2) {
401 initial_capacity = 2;
402 }
403 slots_.resize(initial_capacity);
404 // Slot 0 is reserved (handle 0 = invalid)
405 slots_[0].in_use = true;
406 slots_[0].generation = 0xFFFF; // Never valid
407
408 // Initialize freelist with remaining slots
409 // NOTE: Use i-- > 1 pattern to avoid size_t underflow
410 for (size_t i = initial_capacity; i-- > 1; ) {
411 freelist_.push_back(static_cast<uint16_t>(i));
412 }
413 }
414
    // Destructor body. Warns when drain() was not called first, then frees
    // each live slot's CPU-side resources (no KV traffic here).
    // NOTE(review): the destructor's signature line was elided in this
    // extraction.
        if (tenancy_.ctx != nullptr) {
            LLOYAL_LOG_DEBUG("[BranchStore] WARNING: not drained before destruction");
        }
        for (size_t i = 1; i < slots_.size(); ++i) {
            if (slots_[i].in_use) {
                free_branch_resources(slots_[i]);
            }
        }
    }

    /// Result of allocate(): a branch handle plus its KV sequence lease.
    struct Allocation { BranchHandle handle; llama_seq_id seq_id; };

    // Acquire a seq lease first, then a slot; if no slot is available the
    // untouched lease is returned to the pool (bookkeeping only, no KV
    // calls). Returns {INVALID_HANDLE, NO_LEASE} on any failure or when the
    // store is drained.
    // NOTE(review): the function signature was elided in this extraction —
    // presumably `Allocation allocate() {`, per the struct above and the
    // structured bindings at the call sites in create()/fork().
        if (tenancy_.ctx == nullptr) return {INVALID_HANDLE, NO_LEASE}; // drained
        llama_seq_id seq = kv::tenancy::acquire(tenancy_);
        if (seq < 0) return {INVALID_HANDLE, NO_LEASE};
        BranchHandle handle = allocate_slot();
        if (handle == INVALID_HANDLE) {
            // Seq was never used — bookkeeping-only return, no KV calls
            kv::tenancy::release(tenancy_, seq);
            return {INVALID_HANDLE, NO_LEASE};
        }
        // Stamp seq_id on slot so release() can evict properly
        BranchState* st = get(handle);
        st->seq_id = seq;
        return {handle, seq};
    }
454
    /// Release one branch: detach it from its parent's child list, give its
    /// unique KV cells back to the pressure counter, evict its seq lease,
    /// free CPU-side resources and return the slot to the freelist.
    /// Safe to call with INVALID_HANDLE or a stale handle (no-op).
    void release(BranchHandle handle) {
        if (handle == INVALID_HANDLE) return;
        BranchState* st = get(handle);
        if (!st) return;
        // Eager edge cleanup: remove from parent's children
        if (st->parent != INVALID_HANDLE) {
            BranchState* p = get(st->parent);
            if (p) {
                auto& c = p->children;
                c.erase(std::remove(c.begin(), c.end(), handle), c.end());
            }
        }
        // Subtract unique cells owned by this branch (above fork_head);
        // cells below fork_head are shared with the parent and remain
        // accounted until the whole tree is released.
        if (st->position > st->fork_head) {
            uint32_t unique = static_cast<uint32_t>(st->position - st->fork_head);
            cells_used_ = (unique <= cells_used_) ? cells_used_ - unique : 0; // clamp at 0
        }
        // Evict lease (KV strip + bookkeeping)
        if (st->seq_id != NO_LEASE)
            kv::tenancy::evict(tenancy_, st->seq_id);
        free_branch_resources(*st);
        reset_slot(*st);
        freelist_.push_back(handle_index(handle));

        // All branches released → KV cache empty, reset pressure counter
        // (this also drops any shared-prefix cells never subtracted above)
        if (freelist_.size() == slots_.size() - 1) {
            cells_used_ = 0;
        }
    }
492
493 // ===== TENANCY LIFECYCLE =====
494
499 void init_tenancy(llama_context* ctx) {
500 tenancy_ = kv::tenancy::init(ctx, llama_n_seq_max(ctx));
501 cells_used_ = 0;
502 }
503
512 void drain() {
513 if (tenancy_.ctx == nullptr) return; // idempotent
514 kv::tenancy::evict_all(tenancy_);
515 for (size_t i = 1; i < slots_.size(); ++i) {
516 if (slots_[i].in_use) {
517 free_branch_resources(slots_[i]);
518 reset_slot(slots_[i]);
519 }
520 }
521 tenancy_.ctx = nullptr; // marks as drained
522 cells_used_ = 0;
523 }
524
    // Keep only the winner branch: one "nuclear" KV pass strips every other
    // sequence, then loser slots are torn down CPU-side only. The winner
    // becomes the new root (fork_head 0, no children) and the pressure
    // counter is rebased to its position.
    // NOTE(review): the function signature was elided in this extraction —
    // presumably `void retainOnly(BranchHandle winner) {`, per the error
    // strings below.
        BranchState* w = get(winner);
        if (!w) throw std::runtime_error("retainOnly: invalid winner handle");
        if (w->seq_id == NO_LEASE) throw std::runtime_error("retainOnly: winner has no lease");
        kv::tenancy::retain(tenancy_, w->seq_id); // nuclear KV pass
        // Collect losers first — don't mutate while iterating
        std::vector<BranchHandle> losers;
        for (size_t i = 1; i < slots_.size(); ++i) {
            if (!slots_[i].in_use) continue;
            BranchHandle h = make_handle(static_cast<uint16_t>(i), slots_[i].generation);
            if (h == winner) continue;
            losers.push_back(h);
        }
        for (auto h : losers)
            release_slot_only(h); // CPU only, KV already stripped
        // Winner now owns its whole sequence from position 0.
        w->fork_head = 0;
        w->children.clear();
        cells_used_ = static_cast<uint32_t>(w->position);
    }
554
555 // ===== TOPOLOGY QUERIES =====
556
    /// Number of KV sequence leases still available.
    size_t available() const { return kv::tenancy::available(tenancy_); }

    // Report context pressure as {n_ctx, cells_used, remaining}; all zeros
    // when the store is drained.
    // NOTE(review): the function signature was elided in this extraction —
    // presumably `contextUsage() const` returning the aggregate declared
    // near the top of this file.
        if (!tenancy_.ctx) return {0, 0, 0};
        uint32_t n_ctx = static_cast<uint32_t>(llama_n_ctx(tenancy_.ctx));
        uint32_t remaining = (n_ctx > cells_used_) ? n_ctx - cells_used_ : 0;
        return { n_ctx, cells_used_, remaining };
    }

    /// Charge `n` decoded tokens to the KV pressure counter.
    void add_cells_used(uint32_t n) { cells_used_ += n; }

    // Parent of branch `h`, or INVALID_HANDLE for roots/stale handles.
    // NOTE(review): signature elided — presumably
    // `BranchHandle parent(BranchHandle h) const {`.
        const BranchState* st = get(h);
        return st ? st->parent : INVALID_HANDLE;
    }
589
    /// Position at which branch `h` forked from its parent (0 for roots or
    /// stale handles).
    llama_pos fork_head(BranchHandle h) const {
        const BranchState* st = get(h);
        return st ? st->fork_head : 0;
    }

    /// Children of branch `h`. A function-local static empty vector is
    /// returned for stale handles so callers can always range-iterate.
    const std::vector<BranchHandle>& children(BranchHandle h) const {
        static const std::vector<BranchHandle> empty;
        const BranchState* st = get(h);
        return st ? st->children : empty;
    }

    /// True when `h` has no children (stale handles count as leaves).
    bool isLeaf(BranchHandle h) const {
        const BranchState* st = get(h);
        return st ? st->children.empty() : true;
    }

    /// True when `h` currently holds a KV sequence lease.
    bool isActive(BranchHandle h) const {
        const BranchState* st = get(h);
        return st ? (st->seq_id != NO_LEASE) : false;
    }

    // Resolve a handle to its slot; rejects handle 0, the reserved slot,
    // out-of-range indices and stale generations (ABA guard).
    // NOTE(review): the signature line was elided in this extraction —
    // presumably `BranchState* get(BranchHandle handle) {`, per the const
    // overload below.
        if (handle == INVALID_HANDLE) return nullptr;

        uint16_t index = handle_index(handle);
        uint16_t gen = handle_generation(handle);

        // Slot 0 is reserved and never valid for external use
        if (index == 0) return nullptr;

        if (index >= slots_.size()) return nullptr;

        BranchState& slot = slots_[index];
        if (!slot.in_use || slot.generation != gen) {
            return nullptr;
        }

        return &slot;
    }

    /// Const overload — delegates to the non-const lookup above.
    const BranchState* get(BranchHandle handle) const {
        return const_cast<BranchStore*>(this)->get(handle);
    }
663
664 // ===== SAMPLER CHAIN REGISTRY =====
665
    // Build a sampler chain from `params` and register it under a fresh
    // handle. has_dist records whether the chain ends in a stochastic
    // (dist) step, which is the case when temperature > 0.
    // NOTE(review): the function signature was elided in this extraction —
    // presumably `SamplerChainHandle create_sampler(const P& params) {`.
    template <SamplingParamsLike P>
        SamplerChainHandle h = next_sampler_handle_++;
        SamplerChainEntry entry;
        entry.chain = sampler::create_chain(params);
        float temperature = ::lloyal::detail::as_value(params.temperature, 0.8f);
        entry.has_dist = (temperature > 0.0f);
        sampler_chains_.emplace(h, std::move(entry));
        return h;
    }

    // Deep-clone a registered chain into a new handle; returns 0 for the
    // null handle or an unknown handle.
    // NOTE(review): signature elided — presumably
    // `SamplerChainHandle clone_sampler(SamplerChainHandle h) {`.
        if (h == 0) return 0;
        auto it = sampler_chains_.find(h);
        if (it == sampler_chains_.end()) return 0;
        SamplerChainHandle nh = next_sampler_handle_++;
        SamplerChainEntry entry;
        entry.chain = sampler::clone_chain(it->second.chain);
        entry.has_dist = it->second.has_dist;
        sampler_chains_.emplace(nh, std::move(entry));
        return nh;
    }

    // Erase a registered chain; the null handle 0 is ignored.
    // NOTE(review): signature elided — presumably
    // `void free_sampler(SamplerChainHandle h) {`.
        if (h != 0) sampler_chains_.erase(h);
    }

    /// Borrowed pointer to the chain for `h`, or nullptr when unknown.
    llama_sampler* get_sampler_chain(SamplerChainHandle h) const {
        if (h == 0) return nullptr;
        auto it = sampler_chains_.find(h);
        return it != sampler_chains_.end() ? it->second.chain : nullptr;
    }

    // Whether the chain for `h` ends in a dist (stochastic) sampler.
    // NOTE(review): signature elided — presumably
    // `bool sampler_has_dist(SamplerChainHandle h) const {` (or similar).
        if (h == 0) return false;
        auto it = sampler_chains_.find(h);
        return it != sampler_chains_.end() ? it->second.has_dist : false;
    }
728
729 // ===== GRAMMAR REGISTRY =====
730
738 GrammarHandle create_grammar(const llama_model* model,
739 const char* grammar_str,
740 const char* root = "root") {
741 GrammarHandle h = next_grammar_handle_++;
742 GrammarEntry entry;
743 entry.sampler = grammar::init_sampler(model, grammar_str, root);
744 grammars_.emplace(h, std::move(entry));
745 return h;
746 }
747
    // Lazy grammar: the sampler activates only after one of the trigger
    // patterns or trigger tokens is observed. Returns 0 when the sampler
    // could not be created.
    // NOTE(review): the function name line and the init call line were
    // elided in this extraction — presumably
    // `GrammarHandle create_grammar_lazy(` and
    // `entry.sampler = grammar::init_lazy_sampler(` (or similar).
        const llama_model* model,
        const char* grammar_str,
        const std::vector<std::string>& trigger_patterns,
        const std::vector<llama_token>& trigger_tokens,
        const char* root = "root") {
        GrammarHandle h = next_grammar_handle_++;
        GrammarEntry entry;
            model, grammar_str, trigger_patterns, trigger_tokens, root);
        if (!entry.sampler) return 0;
        grammars_.emplace(h, std::move(entry));
        return h;
    }

    // Deep-clone a registered grammar sampler; 0 for null/unknown handles.
    // NOTE(review): signature elided — presumably
    // `GrammarHandle clone_grammar(GrammarHandle h) {`.
        if (h == 0) return 0;
        auto it = grammars_.find(h);
        if (it == grammars_.end()) return 0;
        GrammarHandle nh = next_grammar_handle_++;
        GrammarEntry entry;
        entry.sampler = grammar::clone_sampler(it->second.sampler);
        grammars_.emplace(nh, std::move(entry));
        return nh;
    }

    // Erase a registered grammar; the null handle 0 is ignored.
    // NOTE(review): signature elided — presumably
    // `void free_grammar(GrammarHandle h) {`.
        if (h != 0) grammars_.erase(h);
    }

    /// Borrowed pointer to the grammar sampler for `h`, or nullptr.
    llama_sampler* get_grammar_sampler(GrammarHandle h) const {
        if (h == 0) return nullptr;
        auto it = grammars_.find(h);
        return it != grammars_.end() ? it->second.sampler : nullptr;
    }
806
807 // ===== METRICS REGISTRY =====
808
    // Register a fresh, zeroed metrics accumulator.
    // NOTE(review): signature elided in this extraction — presumably
    // `MetricsHandle create_metrics() {`.
        MetricsHandle h = next_metrics_handle_++;
        metrics_registry_[h] = metrics::BranchMetricsState{};
        return h;
    }

    // Copy the accumulator for `h` into a new handle (used by fork()); 0
    // for null/unknown handles.
    // NOTE(review): signature elided — presumably
    // `MetricsHandle clone_metrics(MetricsHandle h) {`.
        if (h == 0) return 0;
        auto it = metrics_registry_.find(h);
        if (it == metrics_registry_.end()) return 0;
        MetricsHandle nh = next_metrics_handle_++;
        metrics_registry_[nh] = it->second;
        return nh;
    }

    // Erase a metrics accumulator; the null handle 0 is ignored.
    // NOTE(review): signature elided — presumably
    // `void free_metrics(MetricsHandle h) {`.
        if (h != 0) metrics_registry_.erase(h);
    }

    /// Accumulate one model-surprisal sample (nats). Non-finite samples are
    /// dropped; negative samples are clamped to 0 before summing.
    void add_model_surprisal(MetricsHandle h, float surprisal) {
        if (h == 0) return;
        auto it = metrics_registry_.find(h);
        if (it == metrics_registry_.end()) return;
        if (!std::isfinite(surprisal)) return;
        it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
        it->second.model.count++;
    }

    /// Accumulate one sampling-surprisal sample (nats); same filtering as
    /// add_model_surprisal.
    void add_sampling_surprisal(MetricsHandle h, float surprisal) {
        if (h == 0) return;
        auto it = metrics_registry_.find(h);
        if (it == metrics_registry_.end()) return;
        if (!std::isfinite(surprisal)) return;
        it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
        it->second.sampling.count++;
    }

    // Perplexity of the model distribution: exp(mean nll in nats).
    // Infinity when the handle is null/unknown or no samples were recorded.
    // NOTE(review): signature elided — presumably a `float ...(MetricsHandle
    // h) const` getter.
        if (h == 0) return std::numeric_limits<float>::infinity();
        auto it = metrics_registry_.find(h);
        if (it == metrics_registry_.end() || it->second.model.count == 0)
            return std::numeric_limits<float>::infinity();
        return std::exp(it->second.model.nll_sum_nats /
                        static_cast<float>(it->second.model.count));
    }

    // Same, for the post-sampling distribution.
    // NOTE(review): signature elided.
        if (h == 0) return std::numeric_limits<float>::infinity();
        auto it = metrics_registry_.find(h);
        if (it == metrics_registry_.end() || it->second.sampling.count == 0)
            return std::numeric_limits<float>::infinity();
        return std::exp(it->second.sampling.nll_sum_nats /
                        static_cast<float>(it->second.sampling.count));
    }
896
897 // ===== BATCHED DECODE =====
898
    /// One-token-per-branch batched decode: validates all handles, requires
    /// every branch to share one llama_context, performs a single batched
    /// dispatch, then captures each row's logits into the owning branch and
    /// advances each position by one. Charges one KV cell per item.
    /// @throws std::runtime_error on an invalid handle, mixed contexts,
    ///         decode failure, or an unsized vocab.
    void decode_each(std::span<const DecodeEachItem> items) {
        if (items.empty()) return;

        const int32_t n = static_cast<int32_t>(items.size());

        // Validate handles and collect states
        std::vector<BranchState*> states(n);
        for (int32_t i = 0; i < n; ++i) {
            states[i] = get(items[i].handle);
            if (!states[i]) {
                throw std::runtime_error("BranchStore::decode_each - invalid handle at index " + std::to_string(i));
            }
            if (i > 0 && states[i]->ctx != states[0]->ctx) {
                throw std::runtime_error("BranchStore::decode_each - all branches must share the same context");
            }
        }

        // Build EachItem array from branch states
        std::vector<decode::EachItem> decode_items(n);
        for (int32_t i = 0; i < n; ++i) {
            decode_items[i].token = items[i].token;
            decode_items[i].pos = states[i]->position;
            decode_items[i].seq_id = states[i]->seq_id;
            decode_items[i].output_logits = true; // every row's logits are read back below
        }

        // Single GPU dispatch
        if (decode::each(states[0]->ctx, decode_items.data(), n, scratch_) != 0) {
            throw std::runtime_error("BranchStore::decode_each - llama_decode failed");
        }

        // Capture logits and update positions (row i belongs to item i)
        llama_context* ctx = states[0]->ctx;
        for (int32_t i = 0; i < n; ++i) {
            const float* raw_logits = logits::get(ctx, i); // throws on null
            if (states[i]->n_vocab <= 0) {
                throw std::runtime_error("BranchStore::decode_each - invalid vocab size at index " + std::to_string(i));
            }
            assert(states[i]->logits_snapshot.size() >= static_cast<size_t>(states[i]->n_vocab));
            std::memcpy(states[i]->logits_snapshot.data(), raw_logits,
                        states[i]->n_vocab * sizeof(float));
            states[i]->has_logits = true;
            states[i]->position += 1;
        }
        // One KV cell consumed per decoded token
        cells_used_ += static_cast<uint32_t>(items.size());
    }
967
    /// Variable-length batched decode: bin-packs each branch's token span
    /// into chunks fitting llama_n_batch; spans larger than one batch are
    /// decoded individually via decode::many. For every item the logits of
    /// its final token are captured and the branch position advanced by the
    /// span length. Charges cells_used_ with the total token count.
    /// @throws std::runtime_error on an invalid handle, mixed contexts, or
    ///         any decode failure.
    void decode_scatter(std::span<const DecodeScatterItem> items) {
        if (items.empty()) return;

        const int32_t n = static_cast<int32_t>(items.size());

        // Validate handles and collect states
        std::vector<BranchState*> states(n);
        for (int32_t i = 0; i < n; ++i) {
            states[i] = get(items[i].handle);
            if (!states[i]) {
                throw std::runtime_error("BranchStore::decode_scatter - invalid handle at index " + std::to_string(i));
            }
            if (i > 0 && states[i]->ctx != states[0]->ctx) {
                throw std::runtime_error("BranchStore::decode_scatter - all branches must share the same context");
            }
        }

        llama_context* ctx = states[0]->ctx;
        const int32_t batch_limit = static_cast<int32_t>(llama_n_batch(ctx));

        // Build flat token spans — bin_pack skips empties internally
        std::vector<std::span<const llama_token>> spans(n);
        for (int32_t i = 0; i < n; ++i) {
            spans[i] = items[i].tokens;
        }

        auto chunks = decode::bin_pack(spans.data(), n, batch_limit);

        for (const auto& chunk : chunks) {
            if (chunk.oversized) {
                // Span exceeds one batch: decode it alone, internally split
                // by decode::many; its final-row logits are at index -1.
                int32_t idx = chunk.indices[0];
                int32_t tc = static_cast<int32_t>(items[idx].tokens.size());

                if (decode::many(ctx, items[idx].tokens.data(), tc,
                                 states[idx]->position, batch_limit,
                                 states[idx]->seq_id) != 0) {
                    throw std::runtime_error("BranchStore::decode_scatter - decode::many failed for oversized item " + std::to_string(idx));
                }

                const float* raw_logits = logits::get(ctx, -1);
                assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
                std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
                            states[idx]->n_vocab * sizeof(float));
                states[idx]->has_logits = true;
                states[idx]->position += tc;
                continue;
            }

            // Normal chunk — build ScatterItems and dispatch
            std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
            for (size_t k = 0; k < chunk.indices.size(); ++k) {
                int32_t idx = chunk.indices[k];
                scatter_items[k].tokens = items[idx].tokens;
                scatter_items[k].start_pos = states[idx]->position;
                scatter_items[k].seq_id = states[idx]->seq_id;
                scatter_items[k].output_logits = true;
            }

            if (decode::scatter(ctx, scatter_items.data(),
                                static_cast<int32_t>(scatter_items.size()),
                                scratch_) != 0) {
                throw std::runtime_error("BranchStore::decode_scatter - decode::scatter failed");
            }

            // Capture logits for each item in the chunk: `cursor` walks the
            // flattened rows; each item's last row is cursor + item_n - 1.
            int32_t cursor = 0;
            for (size_t k = 0; k < scatter_items.size(); ++k) {
                int32_t idx = chunk.indices[k];
                int32_t item_n = static_cast<int32_t>(scatter_items[k].tokens.size());

                const float* raw_logits = logits::get(ctx, cursor + item_n - 1);
                assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
                std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
                            states[idx]->n_vocab * sizeof(float));
                states[idx]->has_logits = true;
                states[idx]->position += static_cast<int32_t>(items[idx].tokens.size());

                cursor += item_n;
            }
        }

        // Accumulate total tokens decoded across all items
        for (int32_t i = 0; i < n; ++i) {
            cells_used_ += static_cast<uint32_t>(items[i].tokens.size());
        }
    }
1079
    // Expert merge: dst_logits[t] += alpha * expert_logits[t] for every
    // expert, in place on dst's captured snapshot. All branches must have
    // captured logits and share the same vocab size.
    // NOTE(review): the function name line was elided in this extraction —
    // presumably `void merge_logits(`.
        BranchHandle dst_handle,
        std::span<const BranchHandle> expert_handles,
        float alpha) {
        BranchState* dst = get(dst_handle);
        if (!dst) {
            throw std::runtime_error("BranchStore::merge_logits - invalid dst handle");
        }
        if (!dst->has_logits) {
            throw std::runtime_error("BranchStore::merge_logits - dst has no captured logits");
        }
        if (dst->n_vocab <= 0) {
            throw std::runtime_error("BranchStore::merge_logits - dst n_vocab invalid");
        }

        const int32_t n_vocab = dst->n_vocab;
        float* dst_logits = dst->logits_snapshot.data();

        for (size_t i = 0; i < expert_handles.size(); ++i) {
            BranchState* src = get(expert_handles[i]);
            if (!src) {
                throw std::runtime_error(
                    "BranchStore::merge_logits - invalid expert handle at index " +
                    std::to_string(i));
            }
            if (!src->has_logits) {
                throw std::runtime_error(
                    "BranchStore::merge_logits - expert has no captured logits at index " +
                    std::to_string(i));
            }
            if (src->n_vocab != n_vocab) {
                throw std::runtime_error(
                    "BranchStore::merge_logits - n_vocab mismatch at index " +
                    std::to_string(i));
            }

            const float* src_logits = src->logits_snapshot.data();
            for (int32_t t = 0; t < n_vocab; ++t) {
                dst_logits[t] += alpha * src_logits[t];
            }
        }
    }
1144
1145private:
    /// Free every registry-backed resource a slot owns (sampler chain,
    /// grammar, boundary tracker, metrics) and null the handles/pointers.
    /// NOTE(review): the statement between the sampler_chain check and the
    /// `= 0` assignment was elided in this extraction — presumably
    /// `free_sampler(slot.sampler_chain);`, mirroring the grammar/metrics
    /// branches below; confirm in the header.
    void free_branch_resources(BranchState& slot) {
        if (slot.sampler_chain != 0) {
            slot.sampler_chain = 0;
        }
        if (slot.grammar != 0) {
            free_grammar(slot.grammar);
            slot.grammar = 0;
        }
        if (slot.boundary_tracker) {
            // The tracker is owned by the branch (adopted in create(),
            // cloned in fork()), so it is deleted here.
            delete slot.boundary_tracker;
            slot.boundary_tracker = nullptr;
        }
        if (slot.metrics != 0) {
            free_metrics(slot.metrics);
            slot.metrics = 0;
        }
    }
1165
1167 void reset_slot(BranchState& slot) {
1168 slot.in_use = false;
1169 slot.generation = static_cast<uint16_t>(slot.generation + 1); // Prevent ABA
1170 slot.ctx = nullptr;
1171 slot.model = nullptr;
1172 slot.seq_id = NO_LEASE;
1173 slot.position = 0;
1174 slot.fork_head = 0;
1175 slot.sampler_chain = 0;
1176 slot.grammar = 0;
1177 slot.metrics = 0;
1178 slot.cached_params = CachedSamplingParams{};
1179 slot.last_token = -1;
1180 slot.last_candidates.clear();
1181 slot.logits_snapshot.clear();
1182 slot.has_logits = false;
1183 slot.logit_bias.clear();
1184 slot.steer_fn = nullptr;
1185 slot.candidates_buffer.clear();
1186 slot.n_batch = DEFAULT_N_BATCH;
1187 slot.n_vocab = 0;
1188 slot.parent = INVALID_HANDLE;
1189 slot.children.clear();
1190 }
1191
1193 BranchHandle allocate_slot() {
1194 if (freelist_.empty()) {
1195 size_t old_size = slots_.size();
1196 size_t new_size = old_size * 2;
1197 if (new_size > INDEX_MASK) {
1198 new_size = INDEX_MASK + 1;
1199 }
1200 if (old_size >= new_size) {
1201 LLOYAL_LOG_DEBUG("[branch::allocate_slot] Store full, cannot allocate");
1202 return INVALID_HANDLE;
1203 }
1204 slots_.resize(new_size);
1205 for (size_t i = new_size; i-- > old_size; ) {
1206 freelist_.push_back(static_cast<uint16_t>(i));
1207 }
1208 }
1209 uint16_t index = freelist_.back();
1210 freelist_.pop_back();
1211 BranchState& slot = slots_[index];
1212 slot.in_use = true;
1213 return make_handle(index, slot.generation);
1214 }
1215
1218 void release_slot_only(BranchHandle handle) {
1219 if (handle == INVALID_HANDLE) return;
1220 BranchState* st = get(handle);
1221 if (!st) return;
1222 // Eager edge cleanup
1223 if (st->parent != INVALID_HANDLE) {
1224 BranchState* p = get(st->parent);
1225 if (p) {
1226 auto& c = p->children;
1227 c.erase(std::remove(c.begin(), c.end(), handle), c.end());
1228 }
1229 }
1230 free_branch_resources(*st);
1231 reset_slot(*st);
1232 freelist_.push_back(handle_index(handle));
1233 }
1234
    // Slot storage. std::deque keeps BranchState* returned by get() valid
    // across growth in allocate_slot(), unlike std::vector.
    std::deque<BranchState> slots_;
    std::vector<uint16_t> freelist_; // spare slot indices; back() is allocated next

    // KV sequence-lease bookkeeping; tenancy_.ctx == nullptr marks the
    // store as drained.
    kv::tenancy::State tenancy_;

    // KV cells attributed to live branches — pressure counter reported by
    // contextUsage().
    uint32_t cells_used_ = 0;

    // Reusable batch-building scratch for the batched decode paths.
    decode::Scratch scratch_;

    // ===== Handle registries (instance-scoped, not global static) =====
    // Handle 0 is the null handle in every registry; counters start at 1.

    std::unordered_map<SamplerChainHandle, SamplerChainEntry> sampler_chains_;
    SamplerChainHandle next_sampler_handle_ = 1;

    std::unordered_map<GrammarHandle, GrammarEntry> grammars_;
    GrammarHandle next_grammar_handle_ = 1;

    std::unordered_map<MetricsHandle, metrics::BranchMetricsState> metrics_registry_;
    MetricsHandle next_metrics_handle_ = 1;
};
1264};
1265
1266
1267// ===== BRANCH API =====
1268
// Create a root branch: allocates a slot + seq lease, sizes the vocab
// buffers, builds the sampler chain (and optional grammar), and registers a
// fresh metrics accumulator. Returns INVALID_HANDLE on failure.
// NOTE(review): the function name line was elided in this extraction —
// presumably `inline BranchHandle create(`.
template <SamplingParamsLike P>
    llama_context* ctx,
    const llama_model* model,
    BranchStore& s,
    llama_pos start_pos,
    const P& params,
    int n_batch = DEFAULT_N_BATCH,
    const char* grammar_str = nullptr,
    boundaries::BoundaryTracker* boundary_tracker = nullptr) {
  if (!ctx || !model) {
    LLOYAL_LOG_DEBUG("[branch::create] NULL ctx or model");
    return INVALID_HANDLE;
  }

  auto [handle, seq_id] = s.allocate();
  if (handle == INVALID_HANDLE) {
    return INVALID_HANDLE;
  }

  BranchState* state = s.get(handle);
  if (!state) {
    s.release(handle);
    return INVALID_HANDLE;
  }

  state->ctx = ctx;
  state->model = model;
  // seq_id already stamped by allocate()
  state->position = start_pos;
  state->n_batch = n_batch;

  // Size the logits/candidates buffers to the model vocab once, up front.
  const llama_vocab* vocab = llama_model_get_vocab(model);
  state->n_vocab = llama_vocab_n_tokens(vocab);
  state->logits_snapshot.resize(state->n_vocab);
  state->has_logits = false;
  state->candidates_buffer.resize(state->n_vocab);

  state->sampler_chain = s.create_sampler(params);
  state->cached_params = snapshot_params(params); // memoization baseline

  if (grammar_str && grammar_str[0] != '\0') {
    state->grammar = s.create_grammar(model, grammar_str);
  }

  // NOTE(review): the tracker is adopted as-is and later deleted by
  // free_branch_resources() — callers transfer ownership here; confirm
  // contract with the header documentation.
  state->boundary_tracker = boundary_tracker;
  state->metrics = s.create_metrics();

  LLOYAL_LOG_DEBUG("[branch::create] Created branch handle=%u seq=%d pos=%d",
                   handle, seq_id, start_pos);

  return handle;
}
1342
1343
// Fork a branch: allocate a fresh slot + seq lease, copy the source's KV
// sequence and sampling state, and record parent/child topology. The new
// branch starts at the source's position with fork_head set there, so only
// tokens decoded after the fork count as its unique KV cells.
// NOTE(review): the function signature was elided in this extraction —
// presumably `inline BranchHandle fork(BranchHandle source, BranchStore& s) {`.
  BranchState* src = s.get(source);
  if (!src) {
    LLOYAL_LOG_DEBUG("[branch::fork] Invalid source handle");
    return INVALID_HANDLE;
  }

  auto [new_handle, new_seq_id] = s.allocate();
  if (new_handle == INVALID_HANDLE) {
    return INVALID_HANDLE;
  }

  // NOTE: src stays valid across allocate() — slots_ is a std::deque.
  BranchState* dst = s.get(new_handle);
  if (!dst) {
    s.release(new_handle);
    return INVALID_HANDLE;
  }

  // Copy basic state
  dst->ctx = src->ctx;
  dst->model = src->model;
  dst->seq_id = new_seq_id;
  dst->position = src->position;
  dst->fork_head = src->position;
  dst->n_batch = src->n_batch;
  dst->n_vocab = src->n_vocab;

#ifndef NDEBUG
  assert(kv::pos_max(src->ctx, new_seq_id) < 0 && "tenancy: acquired seq must be clean");
  assert(dst->parent == INVALID_HANDLE && dst->children.empty() && "fresh slot must have no topology");
#endif

  // Fork KV cache
  kv::seq_cp(src->ctx, src->seq_id, new_seq_id);

  // Record topology
  dst->parent = source;
  src->children.push_back(new_handle);

  // Clone sampler chain
  // NOTE(review): the clone statement inside this block was elided in this
  // extraction — presumably
  // `dst->sampler_chain = s.clone_sampler(src->sampler_chain);`.
  if (src->sampler_chain != 0) {
  }
  dst->cached_params = src->cached_params;

  if (src->grammar != 0) {
    dst->grammar = s.clone_grammar(src->grammar);
  }

  if (src->boundary_tracker) {
    // Child gets its own tracker; ownership of the clone transfers to dst.
    dst->boundary_tracker = src->boundary_tracker->clone().release();
  }

  if (src->metrics != 0) {
    dst->metrics = s.clone_metrics(src->metrics);
  }

  dst->last_token = src->last_token;
  dst->last_candidates = src->last_candidates;

  // logits_snapshot copy is intentional: fork's contract is "sample different
  // tokens from the same logit distribution." Without the copy, the child can't
  // sample without a redundant decode. Cost: n_vocab * 4 bytes (~512KB at 128k vocab).
  dst->logits_snapshot = src->logits_snapshot;
  dst->has_logits = src->has_logits;
  dst->logit_bias = src->logit_bias;

  dst->candidates_buffer.resize(dst->n_vocab);

  LLOYAL_LOG_DEBUG("[branch::fork] Forked handle=%u -> handle=%u seq=%d->%d",
                   source, new_handle, src->seq_id, new_seq_id);

  return new_handle;
}
1439
1460inline void set_logit_bias(
1461 BranchHandle handle,
1462 const llama_logit_bias* biases,
1463 size_t n_biases,
1464 BranchStore& s) {
1465
1466
1467 BranchState* state = s.get(handle);
1468 if (!state) {
1469 throw std::runtime_error("set_logit_bias: invalid branch handle");
1470 }
1471
1472 // Replace existing biases with new set
1473 state->logit_bias.assign(biases, biases + n_biases);
1474
1475 LLOYAL_LOG_DEBUG("[branch::set_logit_bias] Set %zu biases on handle=%u",
1476 n_biases, handle);
1477}
1478
// Remove all logit biases from the branch.
// NOTE(review): the function name line was elided in this extraction —
// presumably `inline void clear_logit_bias(`.
    BranchHandle handle,
    BranchStore& s) {

  BranchState* state = s.get(handle);
  if (!state) {
    throw std::runtime_error("clear_logit_bias: invalid branch handle");
  }

  state->logit_bias.clear();

  LLOYAL_LOG_DEBUG("[branch::clear_logit_bias] Cleared biases on handle=%u", handle);
}
1500
1527inline void set_steer(
1528 BranchHandle handle,
1529 std::function<void(llama_token_data_array&)> steer_fn,
1530 BranchStore& s) {
1531
1532
1533 BranchState* state = s.get(handle);
1534 if (!state) {
1535 throw std::runtime_error("set_steer: invalid branch handle");
1536 }
1537
1538 state->steer_fn = std::move(steer_fn);
1539
1540 LLOYAL_LOG_DEBUG("[branch::set_steer] Set steer callback on handle=%u", handle);
1541}
1542
1550inline void clear_steer(
1551 BranchHandle handle,
1552 BranchStore& s) {
1553
1554
1555 BranchState* state = s.get(handle);
1556 if (!state) {
1557 throw std::runtime_error("clear_steer: invalid branch handle");
1558 }
1559
1560 state->steer_fn = nullptr;
1561
1562 LLOYAL_LOG_DEBUG("[branch::clear_steer] Cleared steer callback on handle=%u", handle);
1563}
1564
1580template <SamplingParamsLike P>
1581inline void set_sampler_params(BranchHandle handle, const P& params, BranchStore& s) {
1582 BranchState* state = s.get(handle);
1583 if (!state) {
1584 throw std::runtime_error("set_sampler_params: invalid branch handle");
1585 }
1586
1587 CachedSamplingParams new_params = snapshot_params(params);
1588 if (new_params == state->cached_params && state->sampler_chain != 0) {
1589 return; // Memoized — no rebuild needed
1590 }
1591
1592 // Free old chain
1593 if (state->sampler_chain != 0) {
1594 s.free_sampler(state->sampler_chain);
1595 }
1596
1597 // Create new chain
1598 state->sampler_chain = s.create_sampler(params);
1599 state->cached_params = new_params;
1600
1601 LLOYAL_LOG_DEBUG("[branch::set_sampler_params] Rebuilt chain on handle=%u temp=%.3f",
1602 handle, new_params.temperature);
1603}
1604
1617inline void set_grammar(
1618 BranchHandle handle,
1619 const llama_model* model,
1620 const char* grammar_str,
1621 BranchStore& s) {
1622 BranchState* state = s.get(handle);
1623 if (!state) {
1624 throw std::runtime_error("set_grammar: invalid branch handle");
1625 }
1626
1627 // Free old grammar
1628 if (state->grammar != 0) {
1629 s.free_grammar(state->grammar);
1630 state->grammar = 0;
1631 }
1632
1633 // Create new grammar if provided
1634 if (grammar_str && grammar_str[0] != '\0') {
1635 state->grammar = s.create_grammar(model, grammar_str);
1636 }
1637
1638 LLOYAL_LOG_DEBUG("[branch::set_grammar] %s grammar on handle=%u",
1639 state->grammar != 0 ? "Set" : "Cleared", handle);
1640}
1641
1657 BranchHandle handle,
1658 const llama_model* model,
1659 const char* grammar_str,
1660 const std::vector<std::string>& trigger_patterns,
1661 const std::vector<llama_token>& trigger_tokens,
1662 BranchStore& s) {
1663 BranchState* state = s.get(handle);
1664 if (!state) {
1665 throw std::runtime_error("set_grammar_lazy: invalid branch handle");
1666 }
1667
1668 if (state->grammar != 0) {
1669 s.free_grammar(state->grammar);
1670 state->grammar = 0;
1671 }
1672
1673 if (grammar_str && grammar_str[0] != '\0') {
1674 state->grammar = s.create_grammar_lazy(
1675 model, grammar_str, trigger_patterns, trigger_tokens);
1676 }
1677
1678 LLOYAL_LOG_DEBUG("[branch::set_grammar_lazy] %s grammar on handle=%u",
1679 state->grammar != 0 ? "Set" : "Cleared", handle);
1680}
1681
1692inline void prune(BranchHandle handle, BranchStore& s) {
1693 BranchState* state = s.get(handle);
1694 if (!state) return;
1695 if (!state->children.empty())
1696 throw std::runtime_error("prune: RESTRICT — branch has children. Use pruneSubtree() for CASCADE.");
1697 s.release(handle);
1698}
1699
// CASCADE delete: release `h` and every descendant, children before
// parents so each prune() sees a childless branch.
// NOTE(review): the function signature was elided in this extraction —
// presumably `inline void pruneSubtree(BranchHandle h, BranchStore& s) {`.
  // Iterative DFS. Note: despite its name, `post_order` records the
  // *pre-order* visit sequence; iterating it in reverse still guarantees
  // every node is pruned after all of its descendants, which is what
  // prune()'s RESTRICT check requires.
  std::vector<BranchHandle> stack{h}, post_order;
  while (!stack.empty()) {
    BranchHandle cur = stack.back(); stack.pop_back();
    post_order.push_back(cur);
    BranchState* st = s.get(cur);
    if (st) for (auto child : st->children) stack.push_back(child);
  }
  for (auto it = post_order.rbegin(); it != post_order.rend(); ++it)
    prune(*it, s);
}
1720
1741
1742
1743 BranchState* state = s.get(handle);
1744 if (!state) {
1745 throw std::runtime_error("force_snapshot_logits: invalid branch handle");
1746 }
1747
1748 // logits::get() throws if ctx is null or logits unavailable
1749 const float* raw_logits = logits::get(state->ctx, -1);
1750
1751 if (state->n_vocab <= 0) {
1752 throw std::runtime_error("force_snapshot_logits: invalid vocab size");
1753 }
1754
1755 std::memcpy(state->logits_snapshot.data(), raw_logits,
1756 state->n_vocab * sizeof(float));
1757 state->has_logits = true;
1758}
1759
1776inline void prefill(
1777 BranchHandle handle,
1778 const llama_token* tokens,
1779 size_t n_tokens,
1780 BranchStore& s) {
1781
1782
1783 BranchState* state = s.get(handle);
1784 if (!state) {
1785 throw std::runtime_error("prefill: invalid branch handle");
1786 }
1787
1788 // Pass raw pointer directly - no vector copy needed
1789 if (decode::many(state->ctx, tokens, static_cast<int32_t>(n_tokens),
1790 state->position, state->n_batch, state->seq_id) != 0) {
1791 throw std::runtime_error("prefill: llama_decode failed");
1792 }
1793
1794 state->position += static_cast<llama_pos>(n_tokens);
1795 s.add_cells_used(static_cast<uint32_t>(n_tokens));
1796
1797 // logits::get() throws if logits unavailable
1798 const float* raw_logits = logits::get(state->ctx, -1);
1799
1800 if (state->n_vocab <= 0) {
1801 throw std::runtime_error("prefill: invalid vocab size");
1802 }
1803
1804 std::memcpy(state->logits_snapshot.data(), raw_logits,
1805 state->n_vocab * sizeof(float));
1806 state->has_logits = true;
1807}
1808
1821inline void step(
1822 BranchHandle handle,
1823 llama_token token,
1824 BranchStore& s) {
1825
1826
1827 BranchState* state = s.get(handle);
1828 if (!state) {
1829 throw std::runtime_error("step: invalid branch handle");
1830 }
1831
1832 if (decode::one(state->ctx, token, state->position, state->seq_id, true) != 0) {
1833 throw std::runtime_error("step: llama_decode failed");
1834 }
1835 state->position += 1;
1836 s.add_cells_used(1);
1837
1838 // logits::get() throws if logits unavailable
1839 const float* raw_logits = logits::get(state->ctx, -1);
1840
1841 if (state->n_vocab <= 0) {
1842 throw std::runtime_error("step: invalid vocab size");
1843 }
1844
1845 std::memcpy(state->logits_snapshot.data(), raw_logits,
1846 state->n_vocab * sizeof(float));
1847 state->has_logits = true;
1848}
1849
1860inline const float* get_logits(BranchHandle handle, BranchStore& s) {
1861
1862
1863 const BranchState* state = s.get(handle);
1864 // Must check has_logits, not just empty() - buffer is pre-allocated in create()
1865 if (!state || !state->has_logits) {
1866 return nullptr;
1867 }
1868
1869 return state->logits_snapshot.data();
1870}
1871
1890inline void set_logits(BranchHandle handle, std::span<const float> data, BranchStore& s) {
1891 BranchState* state = s.get(handle);
1892 if (!state) {
1893 throw std::runtime_error("set_logits: invalid handle");
1894 }
1895 if (state->n_vocab <= 0) {
1896 throw std::runtime_error("set_logits: invalid vocab size");
1897 }
1898 if (static_cast<int32_t>(data.size()) != state->n_vocab) {
1899 throw std::runtime_error(
1900 "set_logits: data length (" + std::to_string(data.size()) +
1901 ") does not match branch n_vocab (" + std::to_string(state->n_vocab) + ")");
1902 }
1903
1904 std::memcpy(state->logits_snapshot.data(), data.data(),
1905 state->n_vocab * sizeof(float));
1906 state->has_logits = true;
1907}
1908
/// Sample a token from the branch's captured logits snapshot.
///
/// Pipeline order is significant and must be preserved:
///   1. grammar mask  2. logit bias  3. steer callback  4. sampler chain.
/// On success, also caches the sampled token and the filtered candidate
/// array for later sampling-level metrics in accept_token().
///
/// @return Sampled token id, or -1 if the handle/chain is invalid, no
///         logits are captured, or the chain selects nothing.
inline llama_token sample(BranchHandle handle, BranchStore& s) {


  BranchState* state = s.get(handle);
  llama_sampler* chain = state ? s.get_sampler_chain(state->sampler_chain) : nullptr;
  if (!state || !chain) {
    return -1;
  }

  // Must have logits captured before sampling
  if (!state->has_logits) {
    LLOYAL_LOG_DEBUG("[branch::sample] No logits captured - call prefill()/step() first");
    return -1;
  }

  // Reuse pre-allocated candidates buffer (avoids O(n_vocab) allocs per sample)
  // Candidates are built in token-ID order; later stages rely on that.
  for (int i = 0; i < state->n_vocab; i++) {
    state->candidates_buffer[i] = llama_token_data{
        static_cast<llama_token>(i),
        state->logits_snapshot[i],
        0.0f};
  }

  // selected = -1, sorted = false: chain decides both.
  llama_token_data_array cur_p = {
      state->candidates_buffer.data(),
      static_cast<size_t>(state->n_vocab),
      -1,
      false};

  // Apply grammar first if present (via anti-corruption layer)
  llama_sampler* gram = s.get_grammar_sampler(state->grammar);
  if (gram) {
    grammar::apply(gram, &cur_p);
  }

  // Apply logit bias if present — O(n_biases) via direct index
  // (candidates are in token-ID order; grammar::apply preserves order)
  if (!state->logit_bias.empty()) {
    for (const auto& bias : state->logit_bias) {
      if (bias.token >= 0 && bias.token < state->n_vocab) {
        cur_p.data[bias.token].logit += bias.bias;
      }
    }
  }

  // Apply steer callback if present
  if (state->steer_fn) {
    try {
      state->steer_fn(cur_p);
    } catch (const std::exception& e) {
      LLOYAL_LOG_DEBUG("[branch::sample] Steer exception: %s", e.what());
      // Continue sampling without steer on exception
    }
  }

  // Apply sampler chain (via anti-corruption layer)
  sampler::apply(chain, &cur_p);

  if (cur_p.selected == -1) {
    return -1;
  }

  llama_token token = cur_p.data[cur_p.selected].id;

  // Capture filtered candidates for sampling metrics
  // After BOTH grammar and sampler chain - this is the actual sampling distribution
  state->last_token = token;
  state->last_candidates.clear();
  state->last_candidates.reserve(cur_p.size);

  for (size_t i = 0; i < cur_p.size; i++) {
    state->last_candidates.push_back(cur_p.data[i]);
  }

  return token;
}
1999
2015inline void accept_token(
2016 BranchHandle handle,
2017 llama_token token,
2018 BranchStore& s) {
2019
2020
2021 BranchState* state = s.get(handle);
2022 if (!state) return;
2023
2024 // Accept in grammar (via anti-corruption layer)
2025 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2026 if (gram) {
2027 grammar::accept(gram, token);
2028 }
2029
2030 // Accept in sampler chain for penalty tracking (via anti-corruption layer)
2031 llama_sampler* chain = s.get_sampler_chain(state->sampler_chain);
2032 if (chain) {
2033 sampler::accept(chain, token);
2034 }
2035
2036 // Update model-level perplexity (from raw logits)
2037 // Guard on has_logits to avoid computing surprisal from zero-filled buffer
2038 if (state->metrics != 0 && state->has_logits) {
2039 float ms = metrics::model_surprisal(
2040 state->logits_snapshot.data(), state->n_vocab, token);
2041 s.add_model_surprisal(state->metrics, ms);
2042 }
2043
2044 // Update sampling-level perplexity (from filtered candidates)
2045 if (state->metrics != 0 && !state->last_candidates.empty() &&
2046 token == state->last_token) {
2047 // Extract filtered logits and IDs from candidates
2048 std::vector<float> candidate_logits;
2049 std::vector<int32_t> candidate_ids;
2050 candidate_logits.reserve(state->last_candidates.size());
2051 candidate_ids.reserve(state->last_candidates.size());
2052
2053 for (const auto& cand : state->last_candidates) {
2054 candidate_logits.push_back(cand.logit);
2055 candidate_ids.push_back(cand.id);
2056 }
2057
2058 // Compute sampling-level surprisal
2059 float ss = metrics::sampling_surprisal(
2060 candidate_logits.data(),
2061 candidate_ids.data(),
2062 static_cast<int>(candidate_logits.size()),
2063 token
2064 );
2065 s.add_sampling_surprisal(state->metrics, ss);
2066 }
2067}
2068
2083inline void apply_grammar(
2084 BranchHandle handle,
2085 float* logits,
2086 int n_vocab,
2087 BranchStore& s) {
2088
2089
2090 BranchState* state = s.get(handle);
2091 llama_sampler* gram = state ? s.get_grammar_sampler(state->grammar) : nullptr;
2092 if (!state || !gram) return;
2093
2094 // Use pre-allocated candidates buffer if size matches, otherwise allocate
2095 std::vector<llama_token_data>* candidates_ptr;
2096 std::vector<llama_token_data> temp_buffer;
2097
2098 if (n_vocab == state->n_vocab && !state->candidates_buffer.empty()) {
2099 candidates_ptr = &state->candidates_buffer;
2100 } else {
2101 temp_buffer.resize(n_vocab);
2102 candidates_ptr = &temp_buffer;
2103 }
2104
2105 auto& candidates = *candidates_ptr;
2106 for (int i = 0; i < n_vocab; i++) {
2107 candidates[i] = llama_token_data{
2108 static_cast<llama_token>(i), logits[i], 0.0f};
2109 }
2110
2111 llama_token_data_array cur_p = {
2112 candidates.data(),
2113 static_cast<size_t>(n_vocab),
2114 -1,
2115 false};
2116
2117 grammar::apply(gram, &cur_p);
2118
2119 // Copy masked logits back
2120 for (int i = 0; i < n_vocab; i++) {
2121 logits[i] = candidates[i].logit;
2122 }
2123}
2124
2139inline std::vector<std::pair<llama_token, float>> get_legal_priors(
2140 BranchHandle handle,
2141 BranchStore& s) {
2142
2143
2144 BranchState* state = s.get(handle);
2145 if (!state || !state->has_logits) {
2146 return {};
2147 }
2148
2149 // Reuse pre-allocated candidates buffer
2150 for (int i = 0; i < state->n_vocab; i++) {
2151 state->candidates_buffer[i] = llama_token_data{
2152 static_cast<llama_token>(i),
2153 state->logits_snapshot[i],
2154 0.0f};
2155 }
2156
2157 llama_token_data_array cur_p = {
2158 state->candidates_buffer.data(),
2159 static_cast<size_t>(state->n_vocab),
2160 -1,
2161 false};
2162
2163 // Apply grammar to mask illegal tokens
2164 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2165 if (gram) {
2166 grammar::apply(gram, &cur_p);
2167 }
2168
2169 // Collect legal candidates (logit is finite after grammar masking)
2170 // Grammar masking sets illegal tokens to -INFINITY
2171 std::vector<std::pair<llama_token, float>> legal_priors;
2172 float max_logit = -std::numeric_limits<float>::infinity();
2173
2174 for (size_t i = 0; i < cur_p.size; i++) {
2175 if (std::isfinite(cur_p.data[i].logit)) { // Not masked
2176 legal_priors.emplace_back(cur_p.data[i].id, cur_p.data[i].logit);
2177 if (cur_p.data[i].logit > max_logit) {
2178 max_logit = cur_p.data[i].logit;
2179 }
2180 }
2181 }
2182
2183 if (legal_priors.empty()) {
2184 return {};
2185 }
2186
2187 // Compute softmax over legal moves only (numerically stable)
2188 float sum_exp = 0.0f;
2189 for (auto& [token, logit] : legal_priors) {
2190 float exp_val = std::exp(logit - max_logit);
2191 logit = exp_val; // Temporarily store exp value
2192 sum_exp += exp_val;
2193 }
2194
2195 // Normalize to probabilities
2196 for (auto& [token, prob] : legal_priors) {
2197 prob /= sum_exp;
2198 }
2199
2200 return legal_priors;
2201}
2202
2219
2220
2221 BranchState* state = s.get(handle);
2222 if (!state || !state->has_logits) {
2223 return -std::numeric_limits<float>::infinity();
2224 }
2225
2226 // Reuse pre-allocated candidates buffer
2227 for (int i = 0; i < state->n_vocab; i++) {
2228 state->candidates_buffer[i] = llama_token_data{
2229 static_cast<llama_token>(i),
2230 state->logits_snapshot[i],
2231 0.0f};
2232 }
2233
2234 llama_token_data_array cur_p = {
2235 state->candidates_buffer.data(),
2236 static_cast<size_t>(state->n_vocab),
2237 -1,
2238 false};
2239
2240 // Apply grammar to mask illegal tokens
2241 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2242 if (gram) {
2243 grammar::apply(gram, &cur_p);
2244 }
2245
2246 // Numerically stable logsumexp over legal tokens (finite logits only)
2247 float max_logit = -std::numeric_limits<float>::infinity();
2248 for (size_t i = 0; i < cur_p.size; i++) {
2249 if (std::isfinite(cur_p.data[i].logit) && cur_p.data[i].logit > max_logit) {
2250 max_logit = cur_p.data[i].logit;
2251 }
2252 }
2253
2254 if (!std::isfinite(max_logit)) {
2255 return -std::numeric_limits<float>::infinity(); // No legal tokens
2256 }
2257
2258 float sum_exp = 0.0f;
2259 for (size_t i = 0; i < cur_p.size; i++) {
2260 if (std::isfinite(cur_p.data[i].logit)) {
2261 sum_exp += std::exp(cur_p.data[i].logit - max_logit);
2262 }
2263 }
2264
2265 return max_logit + std::log(sum_exp);
2266}
2267
2279inline bool is_token_legal(
2280 BranchHandle handle,
2281 llama_token token,
2282 BranchStore& s) {
2283
2284
2285 BranchState* state = s.get(handle);
2286 if (!state || token < 0 || token >= state->n_vocab) {
2287 return false;
2288 }
2289
2290 // No grammar = all tokens legal
2291 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2292 if (!gram) {
2293 return true;
2294 }
2295
2296 // Build 1-element candidate array (stack allocated, no heap)
2297 llama_token_data single_candidate = {
2298 token,
2299 state->has_logits ? state->logits_snapshot[token] : 0.0f,
2300 0.0f
2301 };
2302
2303 llama_token_data_array cur_p = {
2304 &single_candidate,
2305 1,
2306 -1,
2307 false
2308 };
2309
2310 // Apply grammar - will set logit to -INFINITY if illegal
2311 grammar::apply(gram, &cur_p);
2312
2313 return std::isfinite(single_candidate.logit);
2314}
2315
2331 BranchHandle handle,
2332 llama_token token,
2333 float logsumexp,
2334 BranchStore& s) {
2335
2336
2337 BranchState* state = s.get(handle);
2338 if (!state || !state->has_logits || token < 0 || token >= state->n_vocab) {
2339 return 0.0f;
2340 }
2341
2342 float logit = state->logits_snapshot[token];
2343 return std::exp(logit - logsumexp);
2344}
2345
2361inline float get_token_prior(
2362 BranchHandle handle,
2363 llama_token token,
2364 float logsumexp,
2365 BranchStore& s) {
2366 if (!is_token_legal(handle, token, s)) {
2367 return 0.0f;
2368 }
2369 return get_token_prior_assume_legal(handle, token, logsumexp, s);
2370}
2371
2372// ===== STATE ACCESSORS =====
2373
2374
2381inline llama_pos get_position(BranchHandle handle, BranchStore& s) {
2382
2383 const BranchState* state = s.get(handle);
2384 return state ? state->position : -1;
2385}
2386
2393inline llama_pos get_fork_head(BranchHandle handle, BranchStore& s) {
2394 const BranchState* state = s.get(handle);
2395 return state ? state->fork_head : 0;
2396}
2397
2409inline float get_perplexity(BranchHandle handle, BranchStore& s) {
2410
2411 const BranchState* state = s.get(handle);
2412 if (!state || state->metrics == 0) {
2413 return std::numeric_limits<float>::infinity();
2414 }
2415 return s.get_model_ppl(state->metrics);
2416}
2417
2430
2431 const BranchState* state = s.get(handle);
2432 if (!state || state->metrics == 0) {
2433 return std::numeric_limits<float>::infinity();
2434 }
2435 return s.get_sampling_ppl(state->metrics);
2436}
2437
2449
2450 const BranchState* state = s.get(handle);
2451
2452 if (!state || state->last_candidates.empty() || state->last_token < 0) {
2453 return 0.0f;
2454 }
2455
2456 // Extract candidates
2457 std::vector<float> candidate_logits;
2458 std::vector<int32_t> candidate_ids;
2459 candidate_logits.reserve(state->last_candidates.size());
2460 candidate_ids.reserve(state->last_candidates.size());
2461
2462 for (const auto& cand : state->last_candidates) {
2463 candidate_logits.push_back(cand.logit);
2464 candidate_ids.push_back(cand.id);
2465 }
2466
2467 // Compute surprisal from filtered distribution
2468 float surprisal = metrics::sampling_surprisal(
2469 candidate_logits.data(),
2470 candidate_ids.data(),
2471 static_cast<int>(candidate_logits.size()),
2472 state->last_token
2473 );
2474
2475 // Convert to probability: P = exp(-surprisal)
2476 return std::exp(-surprisal);
2477}
2478
2485inline int get_n_vocab(BranchHandle handle, BranchStore& s) {
2486
2487 const BranchState* state = s.get(handle);
2488 return state ? state->n_vocab : 0;
2489}
2490
2491
2492// ===== RAII WRAPPER =====
2493
2511class Branch {
2512public:
2513 Branch() : store_(nullptr), handle_(INVALID_HANDLE) {}
2514
2516 : store_(store), handle_(handle) {}
2517
2520 if (handle_ != INVALID_HANDLE && store_) {
2521 branch::pruneSubtree(handle_, *store_);
2522 }
2523 }
2524
2525 Branch(Branch&& other) noexcept
2526 : store_(other.store_), handle_(other.handle_) {
2527 other.handle_ = INVALID_HANDLE;
2528 }
2529
2530 Branch& operator=(Branch&& other) noexcept {
2531 if (this != &other) {
2532 if (handle_ != INVALID_HANDLE && store_) {
2533 branch::pruneSubtree(handle_, *store_);
2534 }
2535 store_ = other.store_;
2536 handle_ = other.handle_;
2537 other.handle_ = INVALID_HANDLE;
2538 }
2539 return *this;
2540 }
2541
2542 Branch(const Branch&) = delete;
2543 Branch& operator=(const Branch&) = delete;
2544
2546 template <SamplingParamsLike P>
2548 llama_context* ctx,
2549 const llama_model* model,
2550 BranchStore& store,
2551 llama_pos start_pos,
2552 const P& params,
2553 int n_batch = DEFAULT_N_BATCH,
2554 const char* grammar_str = nullptr,
2555 boundaries::BoundaryTracker* boundary_tracker = nullptr) {
2556 BranchHandle h = branch::create(ctx, model, store, start_pos, params, n_batch, grammar_str, boundary_tracker);
2557 return Branch(&store, h);
2558 }
2559
2562 BranchHandle h = branch::fork(handle_, *store_);
2563 return Branch(store_, h);
2564 }
2565
2567 void prune() {
2568 branch::prune(handle_, *store_);
2569 handle_ = INVALID_HANDLE;
2570 }
2571
2574 branch::pruneSubtree(handle_, *store_);
2575 handle_ = INVALID_HANDLE;
2576 }
2577
2581 branch::force_snapshot_logits(handle_, *store_);
2582 }
2583
2587 void prefill(const llama_token* tokens, size_t n) {
2588 branch::prefill(handle_, tokens, n, *store_);
2589 }
2590
2593 void step(llama_token token) {
2594 branch::step(handle_, token, *store_);
2595 }
2596
2599 const float* logits() const {
2600 return branch::get_logits(handle_, *store_);
2601 }
2602
2606 llama_token sample() {
2607 return branch::sample(handle_, *store_);
2608 }
2609
2612 void accept(llama_token token) {
2613 branch::accept_token(handle_, token, *store_);
2614 }
2615
2617 bool is_eog(llama_token token) const {
2618 const BranchState* st = store_ ? store_->get(handle_) : nullptr;
2619 return st && st->model ? tokenizer::is_eog(st->model, token) : false;
2620 }
2621
2624 template <SamplingParamsLike P>
2625 void setSamplerParams(const P& params) {
2626 branch::set_sampler_params(handle_, params, *store_);
2627 }
2628
2631 void setGrammar(const char* grammar_str) {
2632 const BranchState* st = store_ ? store_->get(handle_) : nullptr;
2633 branch::set_grammar(handle_, st ? st->model : nullptr, grammar_str, *store_);
2634 }
2635
2636 // ===== ACCESSORS =====
2637
2639 llama_pos position() const { return branch::get_position(handle_, *store_); }
2641 llama_pos forkHead() const { return branch::get_fork_head(handle_, *store_); }
2643 float perplexity() const { return branch::get_perplexity(handle_, *store_); }
2645 int n_vocab() const { return branch::get_n_vocab(handle_, *store_); }
2647 bool valid() const { return handle_ != INVALID_HANDLE; }
2649 BranchHandle handle() const { return handle_; }
2650
2651 // ===== TOPOLOGY =====
2652
2654 BranchHandle parentHandle() const { return store_ ? store_->parent(handle_) : INVALID_HANDLE; }
2656 const std::vector<BranchHandle>& childHandles() const {
2657 static const std::vector<BranchHandle> empty;
2658 return store_ ? store_->children(handle_) : empty;
2659 }
2661 bool isLeaf() const { return store_ ? store_->isLeaf(handle_) : true; }
2663 bool isActive() const { return store_ ? store_->isActive(handle_) : false; }
2664
2665private:
2666 BranchStore* store_;
2667 BranchHandle handle_;
2668};
2669
2670} // namespace lloyal::branch
Stub BoundaryTracker - does nothing.
virtual std::unique_ptr< BoundaryTracker > clone() const
Handle table and batched decode orchestrator for branch management.
Definition branch.hpp:392
GrammarHandle create_grammar_lazy(const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const char *root="root")
Create a lazy grammar (unconstrained until trigger fires)
Definition branch.hpp:757
float get_sampling_ppl(MetricsHandle h) const
Get sampling-level perplexity from a metrics tracker.
Definition branch.hpp:888
bool isActive(BranchHandle h) const
Test whether a branch holds a KV lease.
Definition branch.hpp:626
void free_grammar(GrammarHandle h)
Free a grammar.
Definition branch.hpp:792
GrammarHandle create_grammar(const llama_model *model, const char *grammar_str, const char *root="root")
Create a grammar sampler and register it.
Definition branch.hpp:738
MetricsHandle clone_metrics(MetricsHandle h)
Clone a metrics tracker (for fork)
Definition branch.hpp:824
size_t available() const
Number of vacant seq_ids available for acquisition.
Definition branch.hpp:561
void release(BranchHandle handle)
Release a branch slot + evict its KV lease.
Definition branch.hpp:463
void drain()
Explicit teardown — evict all leases while context is alive.
Definition branch.hpp:512
float get_model_ppl(MetricsHandle h) const
Get model-level perplexity from a metrics tracker.
Definition branch.hpp:874
bool isLeaf(BranchHandle h) const
Test whether a branch is a leaf (no children)
Definition branch.hpp:616
void decode_each(std::span< const DecodeEachItem > items)
Decode one token per branch in a single GPU dispatch.
Definition branch.hpp:921
bool sampler_has_dist(SamplerChainHandle h) const
Check if a sampler chain ends with dist (stochastic) or greedy.
Definition branch.hpp:723
MetricsHandle create_metrics()
Create a metrics tracker and register it.
Definition branch.hpp:813
~BranchStore()
Destructor — frees CPU resources.
Definition branch.hpp:417
SamplerChainHandle clone_sampler(SamplerChainHandle h)
Clone a sampler chain (for fork)
Definition branch.hpp:687
llama_pos fork_head(BranchHandle h) const
Get a branch's fork head (parent position at fork time)
Definition branch.hpp:595
Allocation allocate()
Allocate a branch slot + KV lease atomically.
Definition branch.hpp:439
void add_cells_used(uint32_t n)
Increment cells_used counter (for standalone prefill/step outside BranchStore methods)
Definition branch.hpp:578
const BranchState * get(BranchHandle handle) const
Look up branch state by handle.
Definition branch.hpp:660
void free_sampler(SamplerChainHandle h)
Free a sampler chain.
Definition branch.hpp:703
void retainOnly(BranchHandle winner)
Keep only the winner — nuclear KV + CPU cleanup.
Definition branch.hpp:534
KvPressure kv_pressure() const
KV cache pressure snapshot — O(1), no tree walking.
Definition branch.hpp:570
SamplerChainHandle create_sampler(const P &params)
Create a sampler chain and register it.
Definition branch.hpp:672
llama_sampler * get_grammar_sampler(GrammarHandle h) const
Dereference a grammar handle (non-owning)
Definition branch.hpp:801
void add_sampling_surprisal(MetricsHandle h, float surprisal)
Add sampling-level surprisal to a metrics tracker.
Definition branch.hpp:860
void init_tenancy(llama_context *ctx)
Initialize KV tenancy after context creation.
Definition branch.hpp:499
void merge_logits(BranchHandle dst_handle, std::span< const BranchHandle > expert_handles, float alpha)
Merge experts' logits_snapshot into dst's logits_snapshot.
Definition branch.hpp:1102
GrammarHandle clone_grammar(GrammarHandle h)
Clone a grammar (for fork)
Definition branch.hpp:777
void decode_scatter(std::span< const DecodeScatterItem > items)
Decode variable token counts per branch with auto-chunking.
Definition branch.hpp:993
BranchStore(size_t initial_capacity=16)
Construct a branch store with initial slot capacity.
Definition branch.hpp:398
void free_metrics(MetricsHandle h)
Free a metrics tracker.
Definition branch.hpp:837
const std::vector< BranchHandle > & children(BranchHandle h) const
Get a branch's child handles.
Definition branch.hpp:605
BranchState * get(BranchHandle handle)
Look up branch state by handle.
Definition branch.hpp:640
BranchHandle parent(BranchHandle h) const
Get a branch's parent handle.
Definition branch.hpp:585
void add_model_surprisal(MetricsHandle h, float surprisal)
Add model-level surprisal to a metrics tracker.
Definition branch.hpp:846
llama_sampler * get_sampler_chain(SamplerChainHandle h) const
Dereference a sampler chain handle (non-owning)
Definition branch.hpp:712
~Branch()
Destructor — CASCADE prunes entire subtree.
Definition branch.hpp:2519
void accept(llama_token token)
Accept a token — advance grammar, penalty window, and metrics.
Definition branch.hpp:2612
Branch & operator=(const Branch &)=delete
int n_vocab() const
Vocabulary size.
Definition branch.hpp:2645
void setGrammar(const char *grammar_str)
Replace grammar constraint (nullptr/empty to remove)
Definition branch.hpp:2631
const std::vector< BranchHandle > & childHandles() const
Child branch handles (empty if leaf)
Definition branch.hpp:2656
void setSamplerParams(const P &params)
Replace sampler chain with new parameters (memoized)
Definition branch.hpp:2625
BranchHandle handle() const
Underlying opaque handle (for interop with free functions)
Definition branch.hpp:2649
const float * logits() const
Get the branch's captured logits snapshot.
Definition branch.hpp:2599
bool is_eog(llama_token token) const
Check if a token is end-of-generation for this branch's model.
Definition branch.hpp:2617
BranchHandle parentHandle() const
Parent branch handle, or INVALID_HANDLE if root.
Definition branch.hpp:2654
Branch(BranchStore *store, BranchHandle handle)
Definition branch.hpp:2515
llama_pos position() const
Current decode position (token count)
Definition branch.hpp:2639
llama_token sample()
Sample a token from captured logits.
Definition branch.hpp:2606
void step(llama_token token)
Decode one token and capture logits (generation step)
Definition branch.hpp:2593
void pruneSubtree()
CASCADE prune — removes entire subtree.
Definition branch.hpp:2573
float perplexity() const
Model-level perplexity (from raw logits, pre-filter)
Definition branch.hpp:2643
Branch(const Branch &)=delete
bool isActive() const
True if this branch holds a KV lease.
Definition branch.hpp:2663
llama_pos forkHead() const
Parent's position at fork time (0 for root branches)
Definition branch.hpp:2641
void force_snapshot_logits()
Force-copy shared logits buffer into this branch's snapshot.
Definition branch.hpp:2580
bool isLeaf() const
True if this branch has no children.
Definition branch.hpp:2661
void prefill(const llama_token *tokens, size_t n)
Decode multiple tokens and capture logits atomically (prompt injection)
Definition branch.hpp:2587
bool valid() const
True if this Branch holds a valid handle.
Definition branch.hpp:2647
void prune()
RESTRICT prune (throws if children exist)
Definition branch.hpp:2567
Branch fork()
Fork: allocates slot + lease, records topology edge.
Definition branch.hpp:2561
Branch & operator=(Branch &&other) noexcept
Definition branch.hpp:2530
static Branch create(llama_context *ctx, const llama_model *model, BranchStore &store, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Factory: allocates slot + lease from store.
Definition branch.hpp:2547
Branch(Branch &&other) noexcept
Definition branch.hpp:2525
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:48
Batch Decoding Operations.
Grammar-Constrained Sampling.
constexpr llama_seq_id NO_LEASE
Sentinel value indicating a branch has no KV residency.
Definition kv.hpp:207
KV Cache Physics.
Zero-copy logits access with clear lifetime semantics.
Distribution Metrics for Test-Time Alignment.
void prefill(BranchHandle handle, const llama_token *tokens, size_t n_tokens, BranchStore &s)
Decode multiple tokens and capture logits atomically (prompt prefill)
Definition branch.hpp:1776
void set_logit_bias(BranchHandle handle, const llama_logit_bias *biases, size_t n_biases, BranchStore &s)
Definition branch.hpp:1460
float get_perplexity(BranchHandle handle, BranchStore &s)
Get model-level perplexity (from raw logits)
Definition branch.hpp:2409
void apply_grammar(BranchHandle handle, float *logits, int n_vocab, BranchStore &s)
Apply grammar constraints to an external logits buffer.
Definition branch.hpp:2083
int32_t GrammarHandle
Handle to a grammar sampler in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:154
uint32_t BranchHandle
Opaque handle to a branch slot.
Definition branch.hpp:95
void step(BranchHandle handle, llama_token token, BranchStore &s)
Decode a single token and capture logits (generation step)
Definition branch.hpp:1821
constexpr int DEFAULT_N_BATCH
Default batch size for decode operations.
Definition branch.hpp:99
void set_grammar_lazy(BranchHandle handle, const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, BranchStore &s)
Set lazy grammar on a branch (unconstrained until trigger fires)
Definition branch.hpp:1656
void set_logits(BranchHandle handle, std::span< const float > data, BranchStore &s)
Overwrite a branch's logits_snapshot with caller-provided values.
Definition branch.hpp:1890
void set_steer(BranchHandle handle, std::function< void(llama_token_data_array &)> steer_fn, BranchStore &s)
Definition branch.hpp:1527
constexpr uint32_t GEN_SHIFT
Bit shift for generation field.
Definition branch.hpp:100
void accept_token(BranchHandle handle, llama_token token, BranchStore &s)
Accept a sampled token, advancing grammar and sampler state.
Definition branch.hpp:2015
void force_snapshot_logits(BranchHandle handle, BranchStore &s)
Force-copy the shared llama.cpp logits buffer into this branch's private snapshot.
Definition branch.hpp:1740
llama_pos get_fork_head(BranchHandle handle, BranchStore &s)
Get the branch's fork head (parent position at fork time)
Definition branch.hpp:2393
float get_sampling_perplexity(BranchHandle handle, BranchStore &s)
Get sampling-level perplexity (from filtered distribution)
Definition branch.hpp:2429
float get_legal_logsumexp(BranchHandle handle, BranchStore &s)
Compute log-sum-exp over grammar-legal logits.
Definition branch.hpp:2218
void clear_logit_bias(BranchHandle handle, BranchStore &s)
Clear all logit biases from a branch.
Definition branch.hpp:1486
uint16_t handle_generation(BranchHandle h)
Extract generation counter from a branch handle.
Definition branch.hpp:134
const float * get_logits(BranchHandle handle, BranchStore &s)
Get the branch's captured logits snapshot.
Definition branch.hpp:1860
void set_sampler_params(BranchHandle handle, const P &params, BranchStore &s)
Replace a branch's sampler chain with new parameters.
Definition branch.hpp:1581
void prune(BranchHandle handle, BranchStore &s)
Prune a leaf branch (RESTRICT — throws if children exist)
Definition branch.hpp:1692
BranchHandle create(llama_context *ctx, const llama_model *model, BranchStore &s, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Create a new branch with sampler chain, optional grammar, and metrics.
Definition branch.hpp:1290
CachedSamplingParams snapshot_params(const P &p)
Snapshot sampling params for memoization comparison.
Definition branch.hpp:237
constexpr uint32_t INDEX_MASK
Mask for slot index field.
Definition branch.hpp:101
BranchHandle fork(BranchHandle source, BranchStore &s)
Fork a branch into a new independent sequence.
Definition branch.hpp:1365
int get_n_vocab(BranchHandle handle, BranchStore &s)
Get the branch's vocabulary size.
Definition branch.hpp:2485
llama_token sample(BranchHandle handle, BranchStore &s)
Sample a token from the branch's captured logits.
Definition branch.hpp:1923
void pruneSubtree(BranchHandle h, BranchStore &s)
Prune a branch and all descendants (CASCADE — iterative post-order)
Definition branch.hpp:1709
constexpr BranchHandle INVALID_HANDLE
Null handle sentinel.
Definition branch.hpp:97
float get_token_prior(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token, checking grammar legality first.
Definition branch.hpp:2361
constexpr llama_seq_id NO_LEASE
Branch has no KV residency.
Definition branch.hpp:98
int32_t MetricsHandle
Handle to a metrics tracker in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:157
llama_pos get_position(BranchHandle handle, BranchStore &s)
Get the branch's current decode position.
Definition branch.hpp:2381
void set_grammar(BranchHandle handle, const llama_model *model, const char *grammar_str, BranchStore &s)
Replace a branch's grammar constraint.
Definition branch.hpp:1617
void clear_steer(BranchHandle handle, BranchStore &s)
Clear the steer callback from a branch.
Definition branch.hpp:1550
float get_token_prior_assume_legal(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token known to be grammar-legal.
Definition branch.hpp:2330
BranchHandle make_handle(uint16_t index, uint16_t generation)
Construct a branch handle from index and generation.
Definition branch.hpp:144
uint16_t handle_index(BranchHandle h)
Extract slot index from a branch handle.
Definition branch.hpp:125
int32_t SamplerChainHandle
Handle to a sampler chain in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:151
std::vector< std::pair< llama_token, float > > get_legal_priors(BranchHandle handle, BranchStore &s)
Get grammar-legal tokens with renormalized probabilities.
Definition branch.hpp:2139
bool is_token_legal(BranchHandle handle, llama_token token, BranchStore &s)
Check if a token is legal under grammar constraints.
Definition branch.hpp:2279
float get_last_sampling_prior(BranchHandle handle, BranchStore &s)
Get the last sampled token's prior from the filtered distribution.
Definition branch.hpp:2448
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
Definition decode.hpp:481
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
Definition decode.hpp:239
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
Definition decode.hpp:125
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
Definition decode.hpp:396
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Definition decode.hpp:340
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
Definition sampler.hpp:51
void free_sampler(llama_sampler *smpl)
Free a grammar sampler.
Definition grammar.hpp:235
llama_sampler * clone_sampler(llama_sampler *smpl)
Clone a grammar sampler (for fork/branching).
Definition grammar.hpp:213
llama_sampler * init_sampler(const llama_model *model, const std::string &grammar_str, const std::string &root_rule="root")
Initialize a grammar sampler from GBNF grammar string.
Definition grammar.hpp:107
llama_sampler * init_lazy_sampler(const llama_model *model, const std::string &grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const std::string &root_rule="root")
Initialize a lazy grammar sampler from GBNF grammar string.
Definition grammar.hpp:158
void accept(llama_sampler *smpl, llama_token token)
Accept a token into grammar state.
Definition grammar.hpp:263
void apply(llama_sampler *smpl, llama_token_data_array *cur_p)
Apply grammar constraint to candidates.
Definition grammar.hpp:249
void evict_all(State &s)
Evict every leased seq_id.
Definition kv.hpp:351
llama_seq_id acquire(State &s)
Acquire a seq_id from the vacant pool.
Definition kv.hpp:256
size_t available(const State &s)
Number of vacant seq_ids available for acquisition.
Definition kv.hpp:365
void evict(State &s, llama_seq_id seq)
Evict a seq_id — strip all KV tags then release.
Definition kv.hpp:297
void retain(State &s, llama_seq_id keep)
Nuclear retain — keep one seq, rebuild vacancy from scratch.
Definition kv.hpp:320
State init(llama_context *ctx, llama_seq_id n_seq_max)
Initialize tenancy with all seq_ids vacant.
Definition kv.hpp:234
void release(State &s, llama_seq_id seq)
Release a seq_id back to vacant — bookkeeping only, no KV calls.
Definition kv.hpp:275
void seq_cp(llama_context *ctx, llama_seq_id src, llama_seq_id dst, llama_pos p0=0, llama_pos p1=-1)
Copy KV cache from one sequence to another.
Definition kv.hpp:138
llama_pos pos_max(llama_context *ctx, llama_seq_id seq)
Get maximum position in KV cache sequence.
Definition kv.hpp:111
float * get(llama_context *ctx, int32_t index=-1)
Definition logits.hpp:79
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
Definition metrics.hpp:227
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Definition metrics.hpp:132
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
Definition sampler.hpp:582
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
Definition sampler.hpp:596
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
Definition sampler.hpp:527
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
Definition sampler.hpp:466
void free_chain(llama_sampler *chain)
Free a sampler chain.
Definition sampler.hpp:568
bool is_eog(const llama_vocab *vocab, llama_token token)
Check if token is end-of-generation marker.
Token Sampling Operations.
Consolidated mutable state for a single branch.
Definition branch.hpp:275
bool in_use
True when slot is allocated to an active branch.
Definition branch.hpp:315
const llama_model * model
Llama model (not owned, must outlive branch)
Definition branch.hpp:277
bool has_logits
True only after force_snapshot_logits(), prefill(), or step()
Definition branch.hpp:299
std::function< void(llama_token_data_array &)> steer_fn
Dynamic logit callback, NOT cloned on fork.
Definition branch.hpp:291
int n_batch
Batch size for decode operations.
Definition branch.hpp:311
llama_seq_id seq_id
KV cache sequence identifier (NO_LEASE when inactive)
Definition branch.hpp:279
llama_pos position
Current decode position in the sequence.
Definition branch.hpp:280
std::vector< llama_logit_bias > logit_bias
Static token biases, cloned on fork.
Definition branch.hpp:290
std::vector< llama_token_data > last_candidates
Filtered candidates from last sample()
Definition branch.hpp:296
std::vector< float > logits_snapshot
Captured logit distribution (n_vocab floats)
Definition branch.hpp:298
std::vector< llama_token_data > candidates_buffer
Reusable scratch buffer for sampling (avoids O(n_vocab) allocs per sample call).
Definition branch.hpp:309
boundaries::BoundaryTracker * boundary_tracker
Token boundary detector (owned, optional)
Definition branch.hpp:288
SamplerChainHandle sampler_chain
Handle into BranchStore's sampler registry.
Definition branch.hpp:283
llama_context * ctx
Llama context (not owned, must outlive branch)
Definition branch.hpp:276
BranchHandle parent
Parent branch (INVALID_HANDLE if root)
Definition branch.hpp:318
int n_vocab
Vocabulary size (cached for buffer pre-allocation)
Definition branch.hpp:312
GrammarHandle grammar
Handle into BranchStore's grammar registry.
Definition branch.hpp:284
std::vector< BranchHandle > children
Child branches forked from this one.
Definition branch.hpp:319
llama_token last_token
Last token returned by sample()
Definition branch.hpp:295
MetricsHandle metrics
Handle into BranchStore's metrics registry.
Definition branch.hpp:293
llama_pos fork_head
Parent's position at fork time (0 for root branches)
Definition branch.hpp:281
CachedSamplingParams cached_params
Params used to create current chain (for memoization)
Definition branch.hpp:286
uint16_t generation
Slot generation counter (for ABA prevention)
Definition branch.hpp:314
Result of allocate(): a slot handle + its leased seq_id.
Definition branch.hpp:429
Concrete sampling params snapshot for memoization.
Definition branch.hpp:215
bool operator==(const CachedSamplingParams &) const =default
Item for decode_each: one token per branch.
Definition branch.hpp:330
Item for decode_scatter: variable tokens per branch.
Definition branch.hpp:348
std::span< const llama_token > tokens
Definition branch.hpp:350
RAII entry for a grammar sampler in the registry.
Definition branch.hpp:190
GrammarEntry(GrammarEntry &&o) noexcept
Definition branch.hpp:196
GrammarEntry & operator=(const GrammarEntry &)=delete
GrammarEntry & operator=(GrammarEntry &&o) noexcept
Definition branch.hpp:197
GrammarEntry(const GrammarEntry &)=delete
llama_sampler * sampler
Definition branch.hpp:191
Snapshot of KV cache pressure from BranchStore.
Definition branch.hpp:114
uint32_t remaining
n_ctx - cells_used (clamped to 0)
Definition branch.hpp:117
uint32_t n_ctx
Total KV capacity.
Definition branch.hpp:115
uint32_t cells_used
Cells allocated since last reset.
Definition branch.hpp:116
RAII entry for a sampler chain in the registry.
Definition branch.hpp:165
SamplerChainEntry(const SamplerChainEntry &)=delete
SamplerChainEntry & operator=(SamplerChainEntry &&o) noexcept
Definition branch.hpp:174
SamplerChainEntry(SamplerChainEntry &&o) noexcept
Definition branch.hpp:172
bool has_dist
True if chain ends with dist (temp > 0), false if greedy.
Definition branch.hpp:167
SamplerChainEntry & operator=(const SamplerChainEntry &)=delete
llama_context * ctx
Context for KV operations (nullptr after drain)
Definition kv.hpp:218
Unified model + sampling perplexity tracker.
Definition metrics.hpp:100