liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
branch.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: LicenseRef-FSL-1.1-Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6
57#include "boundaries.hpp"
58#include "common.hpp"
59#include "decode.hpp"
60#include "grammar.hpp"
61#include "kv.hpp"
62#include "logits.hpp"
63#include "metrics.hpp"
64#include "sampler.hpp"
65
66#include <llama/llama.h>
67#include <algorithm> // std::remove
68#include <cassert> // assert
69#include <cmath> // std::exp, std::log, std::isinf, std::isfinite
70#include <cstdint>
71#include <cstring> // std::memcpy
72#include <ctime> // std::time
73#include <deque> // std::deque (pointer stability for BranchStore)
74#include <functional> // std::function
75#include <limits> // std::numeric_limits
76#include <mutex>
77#include <span> // std::span (C++20)
78#include <stdexcept> // std::runtime_error
79#include <string> // std::to_string
80#include <utility> // std::pair, std::exchange
81#include <vector>
82
83namespace lloyal::branch {
84
85// ===== HANDLE TYPE =====
86
95using BranchHandle = uint32_t;
96
98constexpr llama_seq_id NO_LEASE = kv::NO_LEASE;
99constexpr int DEFAULT_N_BATCH = 512;
100constexpr uint32_t GEN_SHIFT = 16;
101constexpr uint32_t INDEX_MASK = 0xFFFF;
102
115 uint32_t n_ctx;
116 uint32_t cells_used;
117 uint32_t remaining;
118};
119
125inline uint16_t handle_index(BranchHandle h) {
126 return static_cast<uint16_t>(h & INDEX_MASK);
127}
128
134inline uint16_t handle_generation(BranchHandle h) {
135 return static_cast<uint16_t>(h >> GEN_SHIFT);
136}
137
144inline BranchHandle make_handle(uint16_t index, uint16_t generation) {
145 return (static_cast<uint32_t>(generation) << GEN_SHIFT) | index;
146}
147
148// ===== REGISTRY HANDLE TYPES =====
149
151using SamplerChainHandle = int32_t;
152
154using GrammarHandle = int32_t;
155
157using MetricsHandle = int32_t;
158
166 llama_sampler* chain = nullptr;
167 bool has_dist = false;
168
169 SamplerChainEntry() = default;
171
173 : chain(o.chain), has_dist(o.has_dist) { o.chain = nullptr; }
175 if (this != &o) {
177 chain = o.chain; has_dist = o.has_dist; o.chain = nullptr;
178 }
179 return *this;
180 }
183};
184
191 llama_sampler* sampler = nullptr;
192
193 GrammarEntry() = default;
195
196 GrammarEntry(GrammarEntry&& o) noexcept : sampler(o.sampler) { o.sampler = nullptr; }
198 if (this != &o) {
200 sampler = o.sampler; o.sampler = nullptr;
201 }
202 return *this;
203 }
204 GrammarEntry(const GrammarEntry&) = delete;
206};
207
216 float temperature = 0.8f;
217 int32_t top_k = 40;
218 float top_p = 0.95f;
219 float typical_p = 1.0f;
220 float min_p = 0.05f;
221 float penalty_repeat = 1.0f;
222 float penalty_freq = 0.0f;
223 float penalty_present = 0.0f;
224 int32_t penalty_last_n = 64;
225 uint32_t seed = 0;
226 bool operator==(const CachedSamplingParams&) const = default;
227};
228
236template <SamplingParamsLike P>
238 using ::lloyal::detail::as_value;
240 as_value(p.temperature, 0.8f),
241 as_value(p.top_k, static_cast<int32_t>(40)),
242 as_value(p.top_p, 0.95f),
243 as_value(p.typical_p, 1.0f),
244 as_value(p.min_p, 0.05f),
245 as_value(p.penalty_repeat, 1.0f),
246 as_value(p.penalty_freq, 0.0f),
247 as_value(p.penalty_present, 0.0f),
248 as_value(p.penalty_last_n, static_cast<int32_t>(64)),
249 as_value(p.seed, static_cast<uint32_t>(0)),
250 };
251}
252
253// ===== BRANCH STATE =====
254
276 llama_context* ctx = nullptr;
277 const llama_model* model = nullptr;
278
279 llama_seq_id seq_id = NO_LEASE;
280 llama_pos position = 0;
281 llama_pos fork_head = 0;
282
285
287
289
290 std::vector<llama_logit_bias> logit_bias;
291 std::function<void(llama_token_data_array&)> steer_fn;
292
294
295 llama_token last_token = -1;
296 std::vector<llama_token_data> last_candidates;
297
298 std::vector<float> logits_snapshot;
299 bool has_logits = false;
300
309 std::vector<llama_token_data> candidates_buffer;
310
312 int n_vocab = 0;
313
314 uint16_t generation = 0;
315 bool in_use = false;
316
317 // Topology — maintained by fork/prune/pruneSubtree
319 std::vector<BranchHandle> children;
320};
321
322// ===== BATCHED DECODE ITEM TYPES =====
323
334
350 std::span<const llama_token> tokens;
351};
352
353// ===== BRANCH STORE (HANDLE TABLE) =====
354
393public:
398 explicit BranchStore(size_t initial_capacity = 16) {
399 // Ensure minimum capacity of 2 (slot 0 reserved + at least 1 usable)
400 if (initial_capacity < 2) {
401 initial_capacity = 2;
402 }
403 slots_.resize(initial_capacity);
404 // Slot 0 is reserved (handle 0 = invalid)
405 slots_[0].in_use = true;
406 slots_[0].generation = 0xFFFF; // Never valid
407
408 // Initialize freelist with remaining slots
409 // NOTE: Use i-- > 1 pattern to avoid size_t underflow
410 for (size_t i = initial_capacity; i-- > 1; ) {
411 freelist_.push_back(static_cast<uint16_t>(i));
412 }
413 }
414
418 if (tenancy_.ctx != nullptr) {
419 LLOYAL_LOG_DEBUG("[BranchStore] WARNING: not drained before destruction");
420 }
421 for (size_t i = 1; i < slots_.size(); ++i) {
422 if (slots_[i].in_use) {
423 free_branch_resources(slots_[i]);
424 }
425 }
426 }
427
429 struct Allocation { BranchHandle handle; llama_seq_id seq_id; };
430
440 if (tenancy_.ctx == nullptr) return {INVALID_HANDLE, NO_LEASE}; // drained
441 llama_seq_id seq = kv::tenancy::acquire(tenancy_);
442 if (seq < 0) return {INVALID_HANDLE, NO_LEASE};
443 BranchHandle handle = allocate_slot();
444 if (handle == INVALID_HANDLE) {
445 // Seq was never used — bookkeeping-only return, no KV calls
446 kv::tenancy::release(tenancy_, seq);
447 return {INVALID_HANDLE, NO_LEASE};
448 }
449 // Stamp seq_id on slot so release() can evict properly
450 BranchState* st = get(handle);
451 st->seq_id = seq;
452 return {handle, seq};
453 }
454
463 void release(BranchHandle handle) {
464 if (handle == INVALID_HANDLE) return;
465 BranchState* st = get(handle);
466 if (!st) return;
467 // Eager edge cleanup: remove from parent's children
468 if (st->parent != INVALID_HANDLE) {
469 BranchState* p = get(st->parent);
470 if (p) {
471 auto& c = p->children;
472 c.erase(std::remove(c.begin(), c.end(), handle), c.end());
473 }
474 }
475 // Subtract unique cells owned by this branch (above fork_head)
476 if (st->position > st->fork_head) {
477 uint32_t unique = static_cast<uint32_t>(st->position - st->fork_head);
478 cells_used_ = (unique <= cells_used_) ? cells_used_ - unique : 0;
479 }
480 // Evict lease (KV strip + bookkeeping)
481 if (st->seq_id != NO_LEASE)
482 kv::tenancy::evict(tenancy_, st->seq_id);
483 free_branch_resources(*st);
484 reset_slot(*st);
485 freelist_.push_back(handle_index(handle));
486
487 // All branches released → KV cache empty, reset pressure counter
488 if (freelist_.size() == slots_.size() - 1) {
489 cells_used_ = 0;
490 }
491 }
492
493 // ===== TENANCY LIFECYCLE =====
494
499 void init_tenancy(llama_context* ctx) {
500 tenancy_ = kv::tenancy::init(ctx, llama_n_seq_max(ctx));
501 cells_used_ = 0;
502 }
503
512 void drain() {
513 if (tenancy_.ctx == nullptr) return; // idempotent
514 kv::tenancy::evict_all(tenancy_);
515 for (size_t i = 1; i < slots_.size(); ++i) {
516 if (slots_[i].in_use) {
517 free_branch_resources(slots_[i]);
518 reset_slot(slots_[i]);
519 }
520 }
521 tenancy_.ctx = nullptr; // marks as drained
522 cells_used_ = 0;
523 }
524
535 BranchState* w = get(winner);
536 if (!w) throw std::runtime_error("retainOnly: invalid winner handle");
537 if (w->seq_id == NO_LEASE) throw std::runtime_error("retainOnly: winner has no lease");
538 kv::tenancy::retain(tenancy_, w->seq_id); // nuclear KV pass
539 // Collect losers first — don't mutate while iterating
540 std::vector<BranchHandle> losers;
541 for (size_t i = 1; i < slots_.size(); ++i) {
542 if (!slots_[i].in_use) continue;
543 BranchHandle h = make_handle(static_cast<uint16_t>(i), slots_[i].generation);
544 if (h == winner) continue;
545 losers.push_back(h);
546 }
547 for (auto h : losers)
548 release_slot_only(h); // CPU only, KV already stripped
550 w->fork_head = 0;
551 w->children.clear();
552 cells_used_ = static_cast<uint32_t>(w->position);
553 }
554
555 // ===== TOPOLOGY QUERIES =====
556
561 size_t available() const { return kv::tenancy::available(tenancy_); }
562
571 if (!tenancy_.ctx) return {0, 0, 0};
572 uint32_t n_ctx = static_cast<uint32_t>(llama_n_ctx(tenancy_.ctx));
573 uint32_t remaining = (n_ctx > cells_used_) ? n_ctx - cells_used_ : 0;
574 return { n_ctx, cells_used_, remaining };
575 }
576
578 void add_cells_used(uint32_t n) { cells_used_ += n; }
579
586 const BranchState* st = get(h);
587 return st ? st->parent : INVALID_HANDLE;
588 }
589
595 llama_pos fork_head(BranchHandle h) const {
596 const BranchState* st = get(h);
597 return st ? st->fork_head : 0;
598 }
599
605 const std::vector<BranchHandle>& children(BranchHandle h) const {
606 static const std::vector<BranchHandle> empty;
607 const BranchState* st = get(h);
608 return st ? st->children : empty;
609 }
610
616 bool isLeaf(BranchHandle h) const {
617 const BranchState* st = get(h);
618 return st ? st->children.empty() : true;
619 }
620
626 bool isActive(BranchHandle h) const {
627 const BranchState* st = get(h);
628 return st ? (st->seq_id != NO_LEASE) : false;
629 }
630
641 if (handle == INVALID_HANDLE) return nullptr;
642
643 uint16_t index = handle_index(handle);
644 uint16_t gen = handle_generation(handle);
645
646 // Slot 0 is reserved and never valid for external use
647 if (index == 0) return nullptr;
648
649 if (index >= slots_.size()) return nullptr;
650
651 BranchState& slot = slots_[index];
652 if (!slot.in_use || slot.generation != gen) {
653 return nullptr;
654 }
655
656 return &slot;
657 }
658
660 const BranchState* get(BranchHandle handle) const {
661 return const_cast<BranchStore*>(this)->get(handle);
662 }
663
664 // ===== SAMPLER CHAIN REGISTRY =====
665
671 template <SamplingParamsLike P>
673 SamplerChainHandle h = next_sampler_handle_++;
674 SamplerChainEntry entry;
675 entry.chain = sampler::create_chain(params);
676 float temperature = ::lloyal::detail::as_value(params.temperature, 0.8f);
677 entry.has_dist = (temperature > 0.0f);
678 sampler_chains_.emplace(h, std::move(entry));
679 return h;
680 }
681
688 if (h == 0) return 0;
689 auto it = sampler_chains_.find(h);
690 if (it == sampler_chains_.end()) return 0;
691 SamplerChainHandle nh = next_sampler_handle_++;
692 SamplerChainEntry entry;
693 entry.chain = sampler::clone_chain(it->second.chain);
694 entry.has_dist = it->second.has_dist;
695 sampler_chains_.emplace(nh, std::move(entry));
696 return nh;
697 }
698
704 if (h != 0) sampler_chains_.erase(h);
705 }
706
712 llama_sampler* get_sampler_chain(SamplerChainHandle h) const {
713 if (h == 0) return nullptr;
714 auto it = sampler_chains_.find(h);
715 return it != sampler_chains_.end() ? it->second.chain : nullptr;
716 }
717
724 if (h == 0) return false;
725 auto it = sampler_chains_.find(h);
726 return it != sampler_chains_.end() ? it->second.has_dist : false;
727 }
728
729 // ===== GRAMMAR REGISTRY =====
730
738 GrammarHandle create_grammar(const llama_model* model,
739 const char* grammar_str,
740 const char* root = "root") {
741 GrammarHandle h = next_grammar_handle_++;
742 GrammarEntry entry;
743 entry.sampler = grammar::init_sampler(model, grammar_str, root);
744 grammars_.emplace(h, std::move(entry));
745 return h;
746 }
747
758 const llama_model* model,
759 const char* grammar_str,
760 const std::vector<std::string>& trigger_patterns,
761 const std::vector<llama_token>& trigger_tokens,
762 const char* root = "root") {
763 GrammarHandle h = next_grammar_handle_++;
764 GrammarEntry entry;
766 model, grammar_str, trigger_patterns, trigger_tokens, root);
767 if (!entry.sampler) return 0;
768 grammars_.emplace(h, std::move(entry));
769 return h;
770 }
771
778 if (h == 0) return 0;
779 auto it = grammars_.find(h);
780 if (it == grammars_.end()) return 0;
781 GrammarHandle nh = next_grammar_handle_++;
782 GrammarEntry entry;
783 entry.sampler = grammar::clone_sampler(it->second.sampler);
784 grammars_.emplace(nh, std::move(entry));
785 return nh;
786 }
787
793 if (h != 0) grammars_.erase(h);
794 }
795
801 llama_sampler* get_grammar_sampler(GrammarHandle h) const {
802 if (h == 0) return nullptr;
803 auto it = grammars_.find(h);
804 return it != grammars_.end() ? it->second.sampler : nullptr;
805 }
806
807 // ===== METRICS REGISTRY =====
808
814 MetricsHandle h = next_metrics_handle_++;
815 metrics_registry_[h] = metrics::BranchMetricsState{};
816 return h;
817 }
818
825 if (h == 0) return 0;
826 auto it = metrics_registry_.find(h);
827 if (it == metrics_registry_.end()) return 0;
828 MetricsHandle nh = next_metrics_handle_++;
829 metrics_registry_[nh] = it->second;
830 return nh;
831 }
832
838 if (h != 0) metrics_registry_.erase(h);
839 }
840
846 void add_model_surprisal(MetricsHandle h, float surprisal) {
847 if (h == 0) return;
848 auto it = metrics_registry_.find(h);
849 if (it == metrics_registry_.end()) return;
850 if (!std::isfinite(surprisal)) return;
851 it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
852 it->second.model.count++;
853 }
854
860 void add_sampling_surprisal(MetricsHandle h, float surprisal) {
861 if (h == 0) return;
862 auto it = metrics_registry_.find(h);
863 if (it == metrics_registry_.end()) return;
864 if (!std::isfinite(surprisal)) return;
865 it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
866 it->second.sampling.count++;
867 }
868
875 if (h == 0) return std::numeric_limits<float>::infinity();
876 auto it = metrics_registry_.find(h);
877 if (it == metrics_registry_.end() || it->second.model.count == 0)
878 return std::numeric_limits<float>::infinity();
879 return std::exp(it->second.model.nll_sum_nats /
880 static_cast<float>(it->second.model.count));
881 }
882
889 if (h == 0) return std::numeric_limits<float>::infinity();
890 auto it = metrics_registry_.find(h);
891 if (it == metrics_registry_.end() || it->second.sampling.count == 0)
892 return std::numeric_limits<float>::infinity();
893 return std::exp(it->second.sampling.nll_sum_nats /
894 static_cast<float>(it->second.sampling.count));
895 }
896
897 // ===== BATCHED DECODE =====
898
921 void decode_each(std::span<const DecodeEachItem> items) {
922 if (items.empty()) return;
923
924 const int32_t n = static_cast<int32_t>(items.size());
925
926 // Validate handles and collect states
927 std::vector<BranchState*> states(n);
928 for (int32_t i = 0; i < n; ++i) {
929 states[i] = get(items[i].handle);
930 if (!states[i]) {
931 throw std::runtime_error("BranchStore::decode_each - invalid handle at index " + std::to_string(i));
932 }
933 if (i > 0 && states[i]->ctx != states[0]->ctx) {
934 throw std::runtime_error("BranchStore::decode_each - all branches must share the same context");
935 }
936 }
937
938 // Build EachItem array from branch states
939 std::vector<decode::EachItem> decode_items(n);
940 for (int32_t i = 0; i < n; ++i) {
941 decode_items[i].token = items[i].token;
942 decode_items[i].pos = states[i]->position;
943 decode_items[i].seq_id = states[i]->seq_id;
944 decode_items[i].output_logits = true;
945 }
946
947 // Single GPU dispatch
948 if (decode::each(states[0]->ctx, decode_items.data(), n, scratch_) != 0) {
949 throw std::runtime_error("BranchStore::decode_each - llama_decode failed");
950 }
951
952 // Capture logits and update positions
953 llama_context* ctx = states[0]->ctx;
954 for (int32_t i = 0; i < n; ++i) {
955 const float* raw_logits = logits::get(ctx, i); // throws on null
956 if (states[i]->n_vocab <= 0) {
957 throw std::runtime_error("BranchStore::decode_each - invalid vocab size at index " + std::to_string(i));
958 }
959 assert(states[i]->logits_snapshot.size() >= static_cast<size_t>(states[i]->n_vocab));
960 std::memcpy(states[i]->logits_snapshot.data(), raw_logits,
961 states[i]->n_vocab * sizeof(float));
962 states[i]->has_logits = true;
963 states[i]->position += 1;
964 }
965 cells_used_ += static_cast<uint32_t>(items.size());
966 }
967
993 void decode_scatter(std::span<const DecodeScatterItem> items) {
994 if (items.empty()) return;
995
996 const int32_t n = static_cast<int32_t>(items.size());
997
998 // Validate handles and collect states
999 std::vector<BranchState*> states(n);
1000 for (int32_t i = 0; i < n; ++i) {
1001 states[i] = get(items[i].handle);
1002 if (!states[i]) {
1003 throw std::runtime_error("BranchStore::decode_scatter - invalid handle at index " + std::to_string(i));
1004 }
1005 if (i > 0 && states[i]->ctx != states[0]->ctx) {
1006 throw std::runtime_error("BranchStore::decode_scatter - all branches must share the same context");
1007 }
1008 }
1009
1010 llama_context* ctx = states[0]->ctx;
1011 const int32_t batch_limit = static_cast<int32_t>(llama_n_batch(ctx));
1012
1013 // Build flat token spans — bin_pack skips empties internally
1014 std::vector<std::span<const llama_token>> spans(n);
1015 for (int32_t i = 0; i < n; ++i) {
1016 spans[i] = items[i].tokens;
1017 }
1018
1019 auto chunks = decode::bin_pack(spans.data(), n, batch_limit);
1020
1021 for (const auto& chunk : chunks) {
1022 if (chunk.oversized) {
1023 int32_t idx = chunk.indices[0];
1024 int32_t tc = static_cast<int32_t>(items[idx].tokens.size());
1025
1026 if (decode::many(ctx, items[idx].tokens.data(), tc,
1027 states[idx]->position, batch_limit,
1028 states[idx]->seq_id) != 0) {
1029 throw std::runtime_error("BranchStore::decode_scatter - decode::many failed for oversized item " + std::to_string(idx));
1030 }
1031
1032 const float* raw_logits = logits::get(ctx, -1);
1033 assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
1034 std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
1035 states[idx]->n_vocab * sizeof(float));
1036 states[idx]->has_logits = true;
1037 states[idx]->position += tc;
1038 continue;
1039 }
1040
1041 // Normal chunk — build ScatterItems and dispatch
1042 std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
1043 for (size_t k = 0; k < chunk.indices.size(); ++k) {
1044 int32_t idx = chunk.indices[k];
1045 scatter_items[k].tokens = items[idx].tokens;
1046 scatter_items[k].start_pos = states[idx]->position;
1047 scatter_items[k].seq_id = states[idx]->seq_id;
1048 scatter_items[k].output_logits = true;
1049 }
1050
1051 if (decode::scatter(ctx, scatter_items.data(),
1052 static_cast<int32_t>(scatter_items.size()),
1053 scratch_) != 0) {
1054 throw std::runtime_error("BranchStore::decode_scatter - decode::scatter failed");
1055 }
1056
1057 // Capture logits for each item in the chunk
1058 int32_t cursor = 0;
1059 for (size_t k = 0; k < scatter_items.size(); ++k) {
1060 int32_t idx = chunk.indices[k];
1061 int32_t item_n = static_cast<int32_t>(scatter_items[k].tokens.size());
1062
1063 const float* raw_logits = logits::get(ctx, cursor + item_n - 1);
1064 assert(states[idx]->logits_snapshot.size() >= static_cast<size_t>(states[idx]->n_vocab));
1065 std::memcpy(states[idx]->logits_snapshot.data(), raw_logits,
1066 states[idx]->n_vocab * sizeof(float));
1067 states[idx]->has_logits = true;
1068 states[idx]->position += static_cast<int32_t>(items[idx].tokens.size());
1069
1070 cursor += item_n;
1071 }
1072 }
1073
1074 // Accumulate total tokens decoded across all items
1075 for (int32_t i = 0; i < n; ++i) {
1076 cells_used_ += static_cast<uint32_t>(items[i].tokens.size());
1077 }
1078 }
1079
1103 BranchHandle dst_handle,
1104 std::span<const BranchHandle> expert_handles,
1105 float alpha) {
1106 BranchState* dst = get(dst_handle);
1107 if (!dst) {
1108 throw std::runtime_error("BranchStore::merge_logits - invalid dst handle");
1109 }
1110 if (!dst->has_logits) {
1111 throw std::runtime_error("BranchStore::merge_logits - dst has no captured logits");
1112 }
1113 if (dst->n_vocab <= 0) {
1114 throw std::runtime_error("BranchStore::merge_logits - dst n_vocab invalid");
1115 }
1116
1117 const int32_t n_vocab = dst->n_vocab;
1118 float* dst_logits = dst->logits_snapshot.data();
1119
1120 for (size_t i = 0; i < expert_handles.size(); ++i) {
1121 BranchState* src = get(expert_handles[i]);
1122 if (!src) {
1123 throw std::runtime_error(
1124 "BranchStore::merge_logits - invalid expert handle at index " +
1125 std::to_string(i));
1126 }
1127 if (!src->has_logits) {
1128 throw std::runtime_error(
1129 "BranchStore::merge_logits - expert has no captured logits at index " +
1130 std::to_string(i));
1131 }
1132 if (src->n_vocab != n_vocab) {
1133 throw std::runtime_error(
1134 "BranchStore::merge_logits - n_vocab mismatch at index " +
1135 std::to_string(i));
1136 }
1137
1138 const float* src_logits = src->logits_snapshot.data();
1139 for (int32_t t = 0; t < n_vocab; ++t) {
1140 dst_logits[t] += alpha * src_logits[t];
1141 }
1142 }
1143 }
1144
1145private:
1147 void free_branch_resources(BranchState& slot) {
1148 if (slot.sampler_chain != 0) {
1150 slot.sampler_chain = 0;
1151 }
1152 if (slot.grammar != 0) {
1153 free_grammar(slot.grammar);
1154 slot.grammar = 0;
1155 }
1156 if (slot.boundary_tracker) {
1157 delete slot.boundary_tracker;
1158 slot.boundary_tracker = nullptr;
1159 }
1160 if (slot.metrics != 0) {
1161 free_metrics(slot.metrics);
1162 slot.metrics = 0;
1163 }
1164 }
1165
1167 void reset_slot(BranchState& slot) {
1168 slot.in_use = false;
1169 slot.generation = static_cast<uint16_t>(slot.generation + 1); // Prevent ABA
1170 slot.ctx = nullptr;
1171 slot.model = nullptr;
1172 slot.seq_id = NO_LEASE;
1173 slot.position = 0;
1174 slot.fork_head = 0;
1175 slot.sampler_chain = 0;
1176 slot.grammar = 0;
1177 slot.metrics = 0;
1178 slot.cached_params = CachedSamplingParams{};
1179 slot.last_token = -1;
1180 slot.last_candidates.clear();
1181 slot.logits_snapshot.clear();
1182 slot.has_logits = false;
1183 slot.logit_bias.clear();
1184 slot.steer_fn = nullptr;
1185 slot.candidates_buffer.clear();
1186 slot.n_batch = DEFAULT_N_BATCH;
1187 slot.n_vocab = 0;
1188 slot.parent = INVALID_HANDLE;
1189 slot.children.clear();
1190 }
1191
1193 BranchHandle allocate_slot() {
1194 if (freelist_.empty()) {
1195 size_t old_size = slots_.size();
1196 size_t new_size = old_size * 2;
1197 if (new_size > INDEX_MASK) {
1198 new_size = INDEX_MASK + 1;
1199 }
1200 if (old_size >= new_size) {
1201 LLOYAL_LOG_DEBUG("[branch::allocate_slot] Store full, cannot allocate");
1202 return INVALID_HANDLE;
1203 }
1204 slots_.resize(new_size);
1205 for (size_t i = new_size; i-- > old_size; ) {
1206 freelist_.push_back(static_cast<uint16_t>(i));
1207 }
1208 }
1209 uint16_t index = freelist_.back();
1210 freelist_.pop_back();
1211 BranchState& slot = slots_[index];
1212 slot.in_use = true;
1213 return make_handle(index, slot.generation);
1214 }
1215
1218 void release_slot_only(BranchHandle handle) {
1219 if (handle == INVALID_HANDLE) return;
1220 BranchState* st = get(handle);
1221 if (!st) return;
1222 // Eager edge cleanup
1223 if (st->parent != INVALID_HANDLE) {
1224 BranchState* p = get(st->parent);
1225 if (p) {
1226 auto& c = p->children;
1227 c.erase(std::remove(c.begin(), c.end(), handle), c.end());
1228 }
1229 }
1230 free_branch_resources(*st);
1231 reset_slot(*st);
1232 freelist_.push_back(handle_index(handle));
1233 }
1234
1237 std::deque<BranchState> slots_;
1238 std::vector<uint16_t> freelist_;
1239
1243 kv::tenancy::State tenancy_;
1244
1248 uint32_t cells_used_ = 0;
1249
1252 decode::Scratch scratch_;
1253
1254 // ===== Handle registries (instance-scoped, not global static) =====
1255
1256 std::unordered_map<SamplerChainHandle, SamplerChainEntry> sampler_chains_;
1257 SamplerChainHandle next_sampler_handle_ = 1;
1258
1259 std::unordered_map<GrammarHandle, GrammarEntry> grammars_;
1260 GrammarHandle next_grammar_handle_ = 1;
1261
1262 std::unordered_map<MetricsHandle, metrics::BranchMetricsState> metrics_registry_;
1263 MetricsHandle next_metrics_handle_ = 1;
1264};
1265
1266
1267// ===== BRANCH API =====
1268
1289template <SamplingParamsLike P>
1291 llama_context* ctx,
1292 const llama_model* model,
1293 BranchStore& s,
1294 llama_pos start_pos,
1295 const P& params,
1296 int n_batch = DEFAULT_N_BATCH,
1297 const char* grammar_str = nullptr,
1298 boundaries::BoundaryTracker* boundary_tracker = nullptr) {
1299 if (!ctx || !model) {
1300 LLOYAL_LOG_DEBUG("[branch::create] NULL ctx or model");
1301 return INVALID_HANDLE;
1302 }
1303
1304 auto [handle, seq_id] = s.allocate();
1305 if (handle == INVALID_HANDLE) {
1306 return INVALID_HANDLE;
1307 }
1308
1309 BranchState* state = s.get(handle);
1310 if (!state) {
1311 s.release(handle);
1312 return INVALID_HANDLE;
1313 }
1314
1315 state->ctx = ctx;
1316 state->model = model;
1317 // seq_id already stamped by allocate()
1318 state->position = start_pos;
1319 state->n_batch = n_batch;
1320
1321 const llama_vocab* vocab = llama_model_get_vocab(model);
1322 state->n_vocab = llama_vocab_n_tokens(vocab);
1323 state->logits_snapshot.resize(state->n_vocab);
1324 state->has_logits = false;
1325 state->candidates_buffer.resize(state->n_vocab);
1326
1327 state->sampler_chain = s.create_sampler(params);
1328 state->cached_params = snapshot_params(params);
1329
1330 if (grammar_str && grammar_str[0] != '\0') {
1331 state->grammar = s.create_grammar(model, grammar_str);
1332 }
1333
1334 state->boundary_tracker = boundary_tracker;
1335 state->metrics = s.create_metrics();
1336
1337 LLOYAL_LOG_DEBUG("[branch::create] Created branch handle=%u seq=%d pos=%d",
1338 handle, seq_id, start_pos);
1339
1340 return handle;
1341}
1342
1343
/// Options accepted by fork().
struct ForkOpts {
  /// When true (the default) the forked branch starts with a copy of the
  /// source's logits snapshot; when false it starts with no captured logits.
  bool clone_logits{true};
};
1360
1383inline BranchHandle fork(BranchHandle source, BranchStore& s, ForkOpts opts = {}) {
1384 BranchState* src = s.get(source);
1385 if (!src) {
1386 LLOYAL_LOG_DEBUG("[branch::fork] Invalid source handle");
1387 return INVALID_HANDLE;
1388 }
1389
1390 auto [new_handle, new_seq_id] = s.allocate();
1391 if (new_handle == INVALID_HANDLE) {
1392 return INVALID_HANDLE;
1393 }
1394
1395 BranchState* dst = s.get(new_handle);
1396 if (!dst) {
1397 s.release(new_handle);
1398 return INVALID_HANDLE;
1399 }
1400
1401 // Copy basic state
1402 dst->ctx = src->ctx;
1403 dst->model = src->model;
1404 dst->seq_id = new_seq_id;
1405 dst->position = src->position;
1406 dst->fork_head = src->position;
1407 dst->n_batch = src->n_batch;
1408 dst->n_vocab = src->n_vocab;
1409
1410#ifndef NDEBUG
1411 assert(kv::pos_max(src->ctx, new_seq_id) < 0 && "tenancy: acquired seq must be clean");
1412 assert(dst->parent == INVALID_HANDLE && dst->children.empty() && "fresh slot must have no topology");
1413#endif
1414
1415 // Fork KV cache
1416 kv::seq_cp(src->ctx, src->seq_id, new_seq_id);
1417
1418 // Record topology
1419 dst->parent = source;
1420 src->children.push_back(new_handle);
1421
1422 // Clone sampler chain
1423 if (src->sampler_chain != 0) {
1424 dst->sampler_chain = s.clone_sampler(src->sampler_chain);
1425 }
1426 dst->cached_params = src->cached_params;
1427
1428 if (src->grammar != 0) {
1429 dst->grammar = s.clone_grammar(src->grammar);
1430 }
1431
1432 if (src->boundary_tracker) {
1433 dst->boundary_tracker = src->boundary_tracker->clone().release();
1434 }
1435
1436 if (src->metrics != 0) {
1437 dst->metrics = s.clone_metrics(src->metrics);
1438 }
1439
1440 dst->last_token = src->last_token;
1441 dst->last_candidates = src->last_candidates;
1442
1443 // Logits snapshot: default preserves fork's "sample same distribution" contract.
1444 // opts.clone_logits=false skips the ~600KB memcpy from src for prefill-overwrite
1445 // consumers (rerank leaves, embedding probes). Buffer is still kept sized to
1446 // n_vocab so subsequent decode_scatter/prefill can write into it.
1447 if (opts.clone_logits) {
1448 dst->logits_snapshot = src->logits_snapshot;
1449 dst->has_logits = src->has_logits;
1450 } else {
1451 if (static_cast<int32_t>(dst->logits_snapshot.size()) != dst->n_vocab) {
1452 dst->logits_snapshot.resize(dst->n_vocab);
1453 }
1454 dst->has_logits = false;
1455 }
1456 dst->logit_bias = src->logit_bias;
1457
1458 dst->candidates_buffer.resize(dst->n_vocab);
1459
1460 LLOYAL_LOG_DEBUG("[branch::fork] Forked handle=%u -> handle=%u seq=%d->%d",
1461 source, new_handle, src->seq_id, new_seq_id);
1462
1463 return new_handle;
1464}
1465
1486inline void set_logit_bias(
1487 BranchHandle handle,
1488 const llama_logit_bias* biases,
1489 size_t n_biases,
1490 BranchStore& s) {
1491
1492
1493 BranchState* state = s.get(handle);
1494 if (!state) {
1495 throw std::runtime_error("set_logit_bias: invalid branch handle");
1496 }
1497
1498 // Replace existing biases with new set
1499 state->logit_bias.assign(biases, biases + n_biases);
1500
1501 LLOYAL_LOG_DEBUG("[branch::set_logit_bias] Set %zu biases on handle=%u",
1502 n_biases, handle);
1503}
1504
1513 BranchHandle handle,
1514 BranchStore& s) {
1515
1516
1517 BranchState* state = s.get(handle);
1518 if (!state) {
1519 throw std::runtime_error("clear_logit_bias: invalid branch handle");
1520 }
1521
1522 state->logit_bias.clear();
1523
1524 LLOYAL_LOG_DEBUG("[branch::clear_logit_bias] Cleared biases on handle=%u", handle);
1525}
1526
1553inline void set_steer(
1554 BranchHandle handle,
1555 std::function<void(llama_token_data_array&)> steer_fn,
1556 BranchStore& s) {
1557
1558
1559 BranchState* state = s.get(handle);
1560 if (!state) {
1561 throw std::runtime_error("set_steer: invalid branch handle");
1562 }
1563
1564 state->steer_fn = std::move(steer_fn);
1565
1566 LLOYAL_LOG_DEBUG("[branch::set_steer] Set steer callback on handle=%u", handle);
1567}
1568
1576inline void clear_steer(
1577 BranchHandle handle,
1578 BranchStore& s) {
1579
1580
1581 BranchState* state = s.get(handle);
1582 if (!state) {
1583 throw std::runtime_error("clear_steer: invalid branch handle");
1584 }
1585
1586 state->steer_fn = nullptr;
1587
1588 LLOYAL_LOG_DEBUG("[branch::clear_steer] Cleared steer callback on handle=%u", handle);
1589}
1590
1606template <SamplingParamsLike P>
1607inline void set_sampler_params(BranchHandle handle, const P& params, BranchStore& s) {
1608 BranchState* state = s.get(handle);
1609 if (!state) {
1610 throw std::runtime_error("set_sampler_params: invalid branch handle");
1611 }
1612
1613 CachedSamplingParams new_params = snapshot_params(params);
1614 if (new_params == state->cached_params && state->sampler_chain != 0) {
1615 return; // Memoized — no rebuild needed
1616 }
1617
1618 // Free old chain
1619 if (state->sampler_chain != 0) {
1620 s.free_sampler(state->sampler_chain);
1621 }
1622
1623 // Create new chain
1624 state->sampler_chain = s.create_sampler(params);
1625 state->cached_params = new_params;
1626
1627 LLOYAL_LOG_DEBUG("[branch::set_sampler_params] Rebuilt chain on handle=%u temp=%.3f",
1628 handle, new_params.temperature);
1629}
1630
1643inline void set_grammar(
1644 BranchHandle handle,
1645 const llama_model* model,
1646 const char* grammar_str,
1647 BranchStore& s) {
1648 BranchState* state = s.get(handle);
1649 if (!state) {
1650 throw std::runtime_error("set_grammar: invalid branch handle");
1651 }
1652
1653 // Free old grammar
1654 if (state->grammar != 0) {
1655 s.free_grammar(state->grammar);
1656 state->grammar = 0;
1657 }
1658
1659 // Create new grammar if provided
1660 if (grammar_str && grammar_str[0] != '\0') {
1661 state->grammar = s.create_grammar(model, grammar_str);
1662 }
1663
1664 LLOYAL_LOG_DEBUG("[branch::set_grammar] %s grammar on handle=%u",
1665 state->grammar != 0 ? "Set" : "Cleared", handle);
1666}
1667
1683 BranchHandle handle,
1684 const llama_model* model,
1685 const char* grammar_str,
1686 const std::vector<std::string>& trigger_patterns,
1687 const std::vector<llama_token>& trigger_tokens,
1688 BranchStore& s) {
1689 BranchState* state = s.get(handle);
1690 if (!state) {
1691 throw std::runtime_error("set_grammar_lazy: invalid branch handle");
1692 }
1693
1694 if (state->grammar != 0) {
1695 s.free_grammar(state->grammar);
1696 state->grammar = 0;
1697 }
1698
1699 if (grammar_str && grammar_str[0] != '\0') {
1700 state->grammar = s.create_grammar_lazy(
1701 model, grammar_str, trigger_patterns, trigger_tokens);
1702 }
1703
1704 LLOYAL_LOG_DEBUG("[branch::set_grammar_lazy] %s grammar on handle=%u",
1705 state->grammar != 0 ? "Set" : "Cleared", handle);
1706}
1707
1718inline void prune(BranchHandle handle, BranchStore& s) {
1719 BranchState* state = s.get(handle);
1720 if (!state) return;
1721 if (!state->children.empty())
1722 throw std::runtime_error("prune: RESTRICT — branch has children. Use pruneSubtree() for CASCADE.");
1723 s.release(handle);
1724}
1725
1736 std::vector<BranchHandle> stack{h}, post_order;
1737 while (!stack.empty()) {
1738 BranchHandle cur = stack.back(); stack.pop_back();
1739 post_order.push_back(cur);
1740 BranchState* st = s.get(cur);
1741 if (st) for (auto child : st->children) stack.push_back(child);
1742 }
1743 for (auto it = post_order.rbegin(); it != post_order.rend(); ++it)
1744 prune(*it, s);
1745}
1746
1767
1768
1769 BranchState* state = s.get(handle);
1770 if (!state) {
1771 throw std::runtime_error("force_snapshot_logits: invalid branch handle");
1772 }
1773
1774 // logits::get() throws if ctx is null or logits unavailable
1775 const float* raw_logits = logits::get(state->ctx, -1);
1776
1777 if (state->n_vocab <= 0) {
1778 throw std::runtime_error("force_snapshot_logits: invalid vocab size");
1779 }
1780
1781 std::memcpy(state->logits_snapshot.data(), raw_logits,
1782 state->n_vocab * sizeof(float));
1783 state->has_logits = true;
1784}
1785
1802inline void prefill(
1803 BranchHandle handle,
1804 const llama_token* tokens,
1805 size_t n_tokens,
1806 BranchStore& s) {
1807
1808
1809 BranchState* state = s.get(handle);
1810 if (!state) {
1811 throw std::runtime_error("prefill: invalid branch handle");
1812 }
1813
1814 // Pass raw pointer directly - no vector copy needed
1815 if (decode::many(state->ctx, tokens, static_cast<int32_t>(n_tokens),
1816 state->position, state->n_batch, state->seq_id) != 0) {
1817 throw std::runtime_error("prefill: llama_decode failed");
1818 }
1819
1820 state->position += static_cast<llama_pos>(n_tokens);
1821 s.add_cells_used(static_cast<uint32_t>(n_tokens));
1822
1823 // logits::get() throws if logits unavailable
1824 const float* raw_logits = logits::get(state->ctx, -1);
1825
1826 if (state->n_vocab <= 0) {
1827 throw std::runtime_error("prefill: invalid vocab size");
1828 }
1829
1830 std::memcpy(state->logits_snapshot.data(), raw_logits,
1831 state->n_vocab * sizeof(float));
1832 state->has_logits = true;
1833}
1834
1847inline void step(
1848 BranchHandle handle,
1849 llama_token token,
1850 BranchStore& s) {
1851
1852
1853 BranchState* state = s.get(handle);
1854 if (!state) {
1855 throw std::runtime_error("step: invalid branch handle");
1856 }
1857
1858 if (decode::one(state->ctx, token, state->position, state->seq_id, true) != 0) {
1859 throw std::runtime_error("step: llama_decode failed");
1860 }
1861 state->position += 1;
1862 s.add_cells_used(1);
1863
1864 // logits::get() throws if logits unavailable
1865 const float* raw_logits = logits::get(state->ctx, -1);
1866
1867 if (state->n_vocab <= 0) {
1868 throw std::runtime_error("step: invalid vocab size");
1869 }
1870
1871 std::memcpy(state->logits_snapshot.data(), raw_logits,
1872 state->n_vocab * sizeof(float));
1873 state->has_logits = true;
1874}
1875
1886inline const float* get_logits(BranchHandle handle, BranchStore& s) {
1887
1888
1889 const BranchState* state = s.get(handle);
1890 // Must check has_logits, not just empty() - buffer is pre-allocated in create()
1891 if (!state || !state->has_logits) {
1892 return nullptr;
1893 }
1894
1895 return state->logits_snapshot.data();
1896}
1897
1916inline void set_logits(BranchHandle handle, std::span<const float> data, BranchStore& s) {
1917 BranchState* state = s.get(handle);
1918 if (!state) {
1919 throw std::runtime_error("set_logits: invalid handle");
1920 }
1921 if (state->n_vocab <= 0) {
1922 throw std::runtime_error("set_logits: invalid vocab size");
1923 }
1924 if (static_cast<int32_t>(data.size()) != state->n_vocab) {
1925 throw std::runtime_error(
1926 "set_logits: data length (" + std::to_string(data.size()) +
1927 ") does not match branch n_vocab (" + std::to_string(state->n_vocab) + ")");
1928 }
1929
1930 std::memcpy(state->logits_snapshot.data(), data.data(),
1931 state->n_vocab * sizeof(float));
1932 state->has_logits = true;
1933}
1934
1955inline void get_logits_at(BranchHandle handle,
1956 std::span<const int32_t> indices,
1957 std::span<float> out,
1958 BranchStore& s) {
1959 if (indices.size() != out.size()) {
1960 throw std::runtime_error("get_logits_at: indices/out size mismatch");
1961 }
1962 const BranchState* state = s.get(handle);
1963 if (!state) {
1964 throw std::runtime_error("get_logits_at: invalid handle");
1965 }
1966 if (!state->has_logits) {
1967 throw std::runtime_error("get_logits_at: no captured logits for handle");
1968 }
1969 const int32_t n_vocab = state->n_vocab;
1970 const float* logits = state->logits_snapshot.data();
1971 for (size_t i = 0; i < indices.size(); ++i) {
1972 const int32_t idx = indices[i];
1973 // Bounds check: indices arrive from JS Int32Array via N-API — must validate
1974 // against n_vocab to prevent native OOB reads driven by JS-controlled offsets.
1975 if (idx < 0 || idx >= n_vocab) {
1976 throw std::runtime_error("get_logits_at: index out of range");
1977 }
1978 out[i] = logits[idx];
1979 }
1980}
1981
1996inline llama_token sample(BranchHandle handle, BranchStore& s) {
1997
1998
1999 BranchState* state = s.get(handle);
2000 llama_sampler* chain = state ? s.get_sampler_chain(state->sampler_chain) : nullptr;
2001 if (!state || !chain) {
2002 return -1;
2003 }
2004
2005 // Must have logits captured before sampling
2006 if (!state->has_logits) {
2007 LLOYAL_LOG_DEBUG("[branch::sample] No logits captured - call prefill()/step() first");
2008 return -1;
2009 }
2010
2011 // Reuse pre-allocated candidates buffer (avoids O(n_vocab) allocs per sample)
2012 for (int i = 0; i < state->n_vocab; i++) {
2013 state->candidates_buffer[i] = llama_token_data{
2014 static_cast<llama_token>(i),
2015 state->logits_snapshot[i],
2016 0.0f};
2017 }
2018
2019 llama_token_data_array cur_p = {
2020 state->candidates_buffer.data(),
2021 static_cast<size_t>(state->n_vocab),
2022 -1,
2023 false};
2024
2025 // Apply grammar first if present (via anti-corruption layer)
2026 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2027 if (gram) {
2028 grammar::apply(gram, &cur_p);
2029 }
2030
2031 // Apply logit bias if present — O(n_biases) via direct index
2032 // (candidates are in token-ID order; grammar::apply preserves order)
2033 if (!state->logit_bias.empty()) {
2034 for (const auto& bias : state->logit_bias) {
2035 if (bias.token >= 0 && bias.token < state->n_vocab) {
2036 cur_p.data[bias.token].logit += bias.bias;
2037 }
2038 }
2039 }
2040
2041 // Apply steer callback if present
2042 if (state->steer_fn) {
2043 try {
2044 state->steer_fn(cur_p);
2045 } catch (const std::exception& e) {
2046 LLOYAL_LOG_DEBUG("[branch::sample] Steer exception: %s", e.what());
2047 // Continue sampling without steer on exception
2048 }
2049 }
2050
2051 // Apply sampler chain (via anti-corruption layer)
2052 sampler::apply(chain, &cur_p);
2053
2054 if (cur_p.selected == -1) {
2055 return -1;
2056 }
2057
2058 llama_token token = cur_p.data[cur_p.selected].id;
2059
2060 // Capture filtered candidates for sampling metrics
2061 // After BOTH grammar and sampler chain - this is the actual sampling distribution
2062 state->last_token = token;
2063 state->last_candidates.clear();
2064 state->last_candidates.reserve(cur_p.size);
2065
2066 for (size_t i = 0; i < cur_p.size; i++) {
2067 state->last_candidates.push_back(cur_p.data[i]);
2068 }
2069
2070 return token;
2071}
2072
2088inline void accept_token(
2089 BranchHandle handle,
2090 llama_token token,
2091 BranchStore& s) {
2092
2093
2094 BranchState* state = s.get(handle);
2095 if (!state) return;
2096
2097 // Accept in grammar (via anti-corruption layer)
2098 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2099 if (gram) {
2100 grammar::accept(gram, token);
2101 }
2102
2103 // Accept in sampler chain for penalty tracking (via anti-corruption layer)
2104 llama_sampler* chain = s.get_sampler_chain(state->sampler_chain);
2105 if (chain) {
2106 sampler::accept(chain, token);
2107 }
2108
2109 // Update model-level perplexity (from raw logits)
2110 // Guard on has_logits to avoid computing surprisal from zero-filled buffer
2111 if (state->metrics != 0 && state->has_logits) {
2112 float ms = metrics::model_surprisal(
2113 state->logits_snapshot.data(), state->n_vocab, token);
2114 s.add_model_surprisal(state->metrics, ms);
2115 }
2116
2117 // Update sampling-level perplexity (from filtered candidates)
2118 if (state->metrics != 0 && !state->last_candidates.empty() &&
2119 token == state->last_token) {
2120 // Extract filtered logits and IDs from candidates
2121 std::vector<float> candidate_logits;
2122 std::vector<int32_t> candidate_ids;
2123 candidate_logits.reserve(state->last_candidates.size());
2124 candidate_ids.reserve(state->last_candidates.size());
2125
2126 for (const auto& cand : state->last_candidates) {
2127 candidate_logits.push_back(cand.logit);
2128 candidate_ids.push_back(cand.id);
2129 }
2130
2131 // Compute sampling-level surprisal
2132 float ss = metrics::sampling_surprisal(
2133 candidate_logits.data(),
2134 candidate_ids.data(),
2135 static_cast<int>(candidate_logits.size()),
2136 token
2137 );
2138 s.add_sampling_surprisal(state->metrics, ss);
2139 }
2140}
2141
2156inline void apply_grammar(
2157 BranchHandle handle,
2158 float* logits,
2159 int n_vocab,
2160 BranchStore& s) {
2161
2162
2163 BranchState* state = s.get(handle);
2164 llama_sampler* gram = state ? s.get_grammar_sampler(state->grammar) : nullptr;
2165 if (!state || !gram) return;
2166
2167 // Use pre-allocated candidates buffer if size matches, otherwise allocate
2168 std::vector<llama_token_data>* candidates_ptr;
2169 std::vector<llama_token_data> temp_buffer;
2170
2171 if (n_vocab == state->n_vocab && !state->candidates_buffer.empty()) {
2172 candidates_ptr = &state->candidates_buffer;
2173 } else {
2174 temp_buffer.resize(n_vocab);
2175 candidates_ptr = &temp_buffer;
2176 }
2177
2178 auto& candidates = *candidates_ptr;
2179 for (int i = 0; i < n_vocab; i++) {
2180 candidates[i] = llama_token_data{
2181 static_cast<llama_token>(i), logits[i], 0.0f};
2182 }
2183
2184 llama_token_data_array cur_p = {
2185 candidates.data(),
2186 static_cast<size_t>(n_vocab),
2187 -1,
2188 false};
2189
2190 grammar::apply(gram, &cur_p);
2191
2192 // Copy masked logits back
2193 for (int i = 0; i < n_vocab; i++) {
2194 logits[i] = candidates[i].logit;
2195 }
2196}
2197
2212inline std::vector<std::pair<llama_token, float>> get_legal_priors(
2213 BranchHandle handle,
2214 BranchStore& s) {
2215
2216
2217 BranchState* state = s.get(handle);
2218 if (!state || !state->has_logits) {
2219 return {};
2220 }
2221
2222 // Reuse pre-allocated candidates buffer
2223 for (int i = 0; i < state->n_vocab; i++) {
2224 state->candidates_buffer[i] = llama_token_data{
2225 static_cast<llama_token>(i),
2226 state->logits_snapshot[i],
2227 0.0f};
2228 }
2229
2230 llama_token_data_array cur_p = {
2231 state->candidates_buffer.data(),
2232 static_cast<size_t>(state->n_vocab),
2233 -1,
2234 false};
2235
2236 // Apply grammar to mask illegal tokens
2237 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2238 if (gram) {
2239 grammar::apply(gram, &cur_p);
2240 }
2241
2242 // Collect legal candidates (logit is finite after grammar masking)
2243 // Grammar masking sets illegal tokens to -INFINITY
2244 std::vector<std::pair<llama_token, float>> legal_priors;
2245 float max_logit = -std::numeric_limits<float>::infinity();
2246
2247 for (size_t i = 0; i < cur_p.size; i++) {
2248 if (std::isfinite(cur_p.data[i].logit)) { // Not masked
2249 legal_priors.emplace_back(cur_p.data[i].id, cur_p.data[i].logit);
2250 if (cur_p.data[i].logit > max_logit) {
2251 max_logit = cur_p.data[i].logit;
2252 }
2253 }
2254 }
2255
2256 if (legal_priors.empty()) {
2257 return {};
2258 }
2259
2260 // Compute softmax over legal moves only (numerically stable)
2261 float sum_exp = 0.0f;
2262 for (auto& [token, logit] : legal_priors) {
2263 float exp_val = std::exp(logit - max_logit);
2264 logit = exp_val; // Temporarily store exp value
2265 sum_exp += exp_val;
2266 }
2267
2268 // Normalize to probabilities
2269 for (auto& [token, prob] : legal_priors) {
2270 prob /= sum_exp;
2271 }
2272
2273 return legal_priors;
2274}
2275
2292
2293
2294 BranchState* state = s.get(handle);
2295 if (!state || !state->has_logits) {
2296 return -std::numeric_limits<float>::infinity();
2297 }
2298
2299 // Reuse pre-allocated candidates buffer
2300 for (int i = 0; i < state->n_vocab; i++) {
2301 state->candidates_buffer[i] = llama_token_data{
2302 static_cast<llama_token>(i),
2303 state->logits_snapshot[i],
2304 0.0f};
2305 }
2306
2307 llama_token_data_array cur_p = {
2308 state->candidates_buffer.data(),
2309 static_cast<size_t>(state->n_vocab),
2310 -1,
2311 false};
2312
2313 // Apply grammar to mask illegal tokens
2314 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2315 if (gram) {
2316 grammar::apply(gram, &cur_p);
2317 }
2318
2319 // Numerically stable logsumexp over legal tokens (finite logits only)
2320 float max_logit = -std::numeric_limits<float>::infinity();
2321 for (size_t i = 0; i < cur_p.size; i++) {
2322 if (std::isfinite(cur_p.data[i].logit) && cur_p.data[i].logit > max_logit) {
2323 max_logit = cur_p.data[i].logit;
2324 }
2325 }
2326
2327 if (!std::isfinite(max_logit)) {
2328 return -std::numeric_limits<float>::infinity(); // No legal tokens
2329 }
2330
2331 float sum_exp = 0.0f;
2332 for (size_t i = 0; i < cur_p.size; i++) {
2333 if (std::isfinite(cur_p.data[i].logit)) {
2334 sum_exp += std::exp(cur_p.data[i].logit - max_logit);
2335 }
2336 }
2337
2338 return max_logit + std::log(sum_exp);
2339}
2340
2352inline bool is_token_legal(
2353 BranchHandle handle,
2354 llama_token token,
2355 BranchStore& s) {
2356
2357
2358 BranchState* state = s.get(handle);
2359 if (!state || token < 0 || token >= state->n_vocab) {
2360 return false;
2361 }
2362
2363 // No grammar = all tokens legal
2364 llama_sampler* gram = s.get_grammar_sampler(state->grammar);
2365 if (!gram) {
2366 return true;
2367 }
2368
2369 // Build 1-element candidate array (stack allocated, no heap)
2370 llama_token_data single_candidate = {
2371 token,
2372 state->has_logits ? state->logits_snapshot[token] : 0.0f,
2373 0.0f
2374 };
2375
2376 llama_token_data_array cur_p = {
2377 &single_candidate,
2378 1,
2379 -1,
2380 false
2381 };
2382
2383 // Apply grammar - will set logit to -INFINITY if illegal
2384 grammar::apply(gram, &cur_p);
2385
2386 return std::isfinite(single_candidate.logit);
2387}
2388
2404 BranchHandle handle,
2405 llama_token token,
2406 float logsumexp,
2407 BranchStore& s) {
2408
2409
2410 BranchState* state = s.get(handle);
2411 if (!state || !state->has_logits || token < 0 || token >= state->n_vocab) {
2412 return 0.0f;
2413 }
2414
2415 float logit = state->logits_snapshot[token];
2416 return std::exp(logit - logsumexp);
2417}
2418
2434inline float get_token_prior(
2435 BranchHandle handle,
2436 llama_token token,
2437 float logsumexp,
2438 BranchStore& s) {
2439 if (!is_token_legal(handle, token, s)) {
2440 return 0.0f;
2441 }
2442 return get_token_prior_assume_legal(handle, token, logsumexp, s);
2443}
2444
2445// ===== STATE ACCESSORS =====
2446
2447
2454inline llama_pos get_position(BranchHandle handle, BranchStore& s) {
2455
2456 const BranchState* state = s.get(handle);
2457 return state ? state->position : -1;
2458}
2459
2466inline llama_pos get_fork_head(BranchHandle handle, BranchStore& s) {
2467 const BranchState* state = s.get(handle);
2468 return state ? state->fork_head : 0;
2469}
2470
2482inline float get_perplexity(BranchHandle handle, BranchStore& s) {
2483
2484 const BranchState* state = s.get(handle);
2485 if (!state || state->metrics == 0) {
2486 return std::numeric_limits<float>::infinity();
2487 }
2488 return s.get_model_ppl(state->metrics);
2489}
2490
2503
2504 const BranchState* state = s.get(handle);
2505 if (!state || state->metrics == 0) {
2506 return std::numeric_limits<float>::infinity();
2507 }
2508 return s.get_sampling_ppl(state->metrics);
2509}
2510
2522
// NOTE(review): the declaration line of this function (listing line 2521) is
// missing from this extract, so its name and return type are not visible
// here. The body uses `handle` and `s` and returns a float value.
//
// Purpose, as far as the body shows: re-derive the probability of the last
// sampled token under the candidate list that was captured when it was drawn.
2523 const BranchState* state = s.get(handle);
2524
// Returns 0 when there is nothing to evaluate: invalid handle, no captured
// candidates, or no valid last token (negative id).
2525 if (!state || state->last_candidates.empty() || state->last_token < 0) {
2526 return 0.0f;
2527 }
2528
// Split the candidate structs into the parallel logit / id arrays that
// metrics::sampling_surprisal() takes.
2529 // Extract candidates
2530 std::vector<float> candidate_logits;
2531 std::vector<int32_t> candidate_ids;
2532 candidate_logits.reserve(state->last_candidates.size());
2533 candidate_ids.reserve(state->last_candidates.size());
2534
2535 for (const auto& cand : state->last_candidates) {
2536 candidate_logits.push_back(cand.logit);
2537 candidate_ids.push_back(cand.id);
2538 }
2539
2540 // Compute surprisal from filtered distribution
2541 float surprisal = metrics::sampling_surprisal(
2542 candidate_logits.data(),
2543 candidate_ids.data(),
2544 static_cast<int>(candidate_logits.size()),
2545 state->last_token
2546 );
2547
// NOTE(review): exp(-surprisal) is a probability only if
// metrics::sampling_surprisal() reports natural-log units — confirm.
2548 // Convert to probability: P = exp(-surprisal)
2549 return std::exp(-surprisal);
2550}
2551
2558inline int get_n_vocab(BranchHandle handle, BranchStore& s) {
2559
2560 const BranchState* state = s.get(handle);
2561 return state ? state->n_vocab : 0;
2562}
2563
2564
2565// ===== RAII WRAPPER =====
2566
2584class Branch {
2585public:
2586 Branch() : store_(nullptr), handle_(INVALID_HANDLE) {}
2587
2589 : store_(store), handle_(handle) {}
2590
2593 if (handle_ != INVALID_HANDLE && store_) {
2594 branch::pruneSubtree(handle_, *store_);
2595 }
2596 }
2597
2598 Branch(Branch&& other) noexcept
2599 : store_(other.store_), handle_(other.handle_) {
2600 other.handle_ = INVALID_HANDLE;
2601 }
2602
2603 Branch& operator=(Branch&& other) noexcept {
2604 if (this != &other) {
2605 if (handle_ != INVALID_HANDLE && store_) {
2606 branch::pruneSubtree(handle_, *store_);
2607 }
2608 store_ = other.store_;
2609 handle_ = other.handle_;
2610 other.handle_ = INVALID_HANDLE;
2611 }
2612 return *this;
2613 }
2614
2615 Branch(const Branch&) = delete;
2616 Branch& operator=(const Branch&) = delete;
2617
2619 template <SamplingParamsLike P>
2621 llama_context* ctx,
2622 const llama_model* model,
2623 BranchStore& store,
2624 llama_pos start_pos,
2625 const P& params,
2626 int n_batch = DEFAULT_N_BATCH,
2627 const char* grammar_str = nullptr,
2628 boundaries::BoundaryTracker* boundary_tracker = nullptr) {
2629 BranchHandle h = branch::create(ctx, model, store, start_pos, params, n_batch, grammar_str, boundary_tracker);
2630 return Branch(&store, h);
2631 }
2632
2635 BranchHandle h = branch::fork(handle_, *store_);
2636 return Branch(store_, h);
2637 }
2638
2640 void prune() {
2641 branch::prune(handle_, *store_);
2642 handle_ = INVALID_HANDLE;
2643 }
2644
2647 branch::pruneSubtree(handle_, *store_);
2648 handle_ = INVALID_HANDLE;
2649 }
2650
2654 branch::force_snapshot_logits(handle_, *store_);
2655 }
2656
2660 void prefill(const llama_token* tokens, size_t n) {
2661 branch::prefill(handle_, tokens, n, *store_);
2662 }
2663
2666 void step(llama_token token) {
2667 branch::step(handle_, token, *store_);
2668 }
2669
2672 const float* logits() const {
2673 return branch::get_logits(handle_, *store_);
2674 }
2675
2679 llama_token sample() {
2680 return branch::sample(handle_, *store_);
2681 }
2682
2685 void accept(llama_token token) {
2686 branch::accept_token(handle_, token, *store_);
2687 }
2688
2690 bool is_eog(llama_token token) const {
2691 const BranchState* st = store_ ? store_->get(handle_) : nullptr;
2692 return st && st->model ? tokenizer::is_eog(st->model, token) : false;
2693 }
2694
2697 template <SamplingParamsLike P>
2698 void setSamplerParams(const P& params) {
2699 branch::set_sampler_params(handle_, params, *store_);
2700 }
2701
2704 void setGrammar(const char* grammar_str) {
2705 const BranchState* st = store_ ? store_->get(handle_) : nullptr;
2706 branch::set_grammar(handle_, st ? st->model : nullptr, grammar_str, *store_);
2707 }
2708
2709 // ===== ACCESSORS =====
2710
2712 llama_pos position() const { return branch::get_position(handle_, *store_); }
2714 llama_pos forkHead() const { return branch::get_fork_head(handle_, *store_); }
2716 float perplexity() const { return branch::get_perplexity(handle_, *store_); }
2718 int n_vocab() const { return branch::get_n_vocab(handle_, *store_); }
2720 bool valid() const { return handle_ != INVALID_HANDLE; }
2722 BranchHandle handle() const { return handle_; }
2723
2724 // ===== TOPOLOGY =====
2725
2727 BranchHandle parentHandle() const { return store_ ? store_->parent(handle_) : INVALID_HANDLE; }
2729 const std::vector<BranchHandle>& childHandles() const {
2730 static const std::vector<BranchHandle> empty;
2731 return store_ ? store_->children(handle_) : empty;
2732 }
2734 bool isLeaf() const { return store_ ? store_->isLeaf(handle_) : true; }
2736 bool isActive() const { return store_ ? store_->isActive(handle_) : false; }
2737
2738private:
2739 BranchStore* store_;
2740 BranchHandle handle_;
2741};
2742
2743} // namespace lloyal::branch
Stub BoundaryTracker - does nothing.
Handle table and batched decode orchestrator for branch management.
Definition branch.hpp:392
GrammarHandle create_grammar_lazy(const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const char *root="root")
Create a lazy grammar (unconstrained until trigger fires)
Definition branch.hpp:757
float get_sampling_ppl(MetricsHandle h) const
Get sampling-level perplexity from a metrics tracker.
Definition branch.hpp:888
bool isActive(BranchHandle h) const
Test whether a branch holds a KV lease.
Definition branch.hpp:626
void free_grammar(GrammarHandle h)
Free a grammar.
Definition branch.hpp:792
GrammarHandle create_grammar(const llama_model *model, const char *grammar_str, const char *root="root")
Create a grammar sampler and register it.
Definition branch.hpp:738
MetricsHandle clone_metrics(MetricsHandle h)
Clone a metrics tracker (for fork)
Definition branch.hpp:824
size_t available() const
Number of vacant seq_ids available for acquisition.
Definition branch.hpp:561
void release(BranchHandle handle)
Release a branch slot + evict its KV lease.
Definition branch.hpp:463
void drain()
Explicit teardown — evict all leases while context is alive.
Definition branch.hpp:512
float get_model_ppl(MetricsHandle h) const
Get model-level perplexity from a metrics tracker.
Definition branch.hpp:874
bool isLeaf(BranchHandle h) const
Test whether a branch is a leaf (no children)
Definition branch.hpp:616
void decode_each(std::span< const DecodeEachItem > items)
Decode one token per branch in a single GPU dispatch.
Definition branch.hpp:921
bool sampler_has_dist(SamplerChainHandle h) const
Check if a sampler chain ends with dist (stochastic) or greedy.
Definition branch.hpp:723
MetricsHandle create_metrics()
Create a metrics tracker and register it.
Definition branch.hpp:813
~BranchStore()
Destructor — frees CPU resources.
Definition branch.hpp:417
SamplerChainHandle clone_sampler(SamplerChainHandle h)
Clone a sampler chain (for fork)
Definition branch.hpp:687
llama_pos fork_head(BranchHandle h) const
Get a branch's fork head (parent position at fork time)
Definition branch.hpp:595
Allocation allocate()
Allocate a branch slot + KV lease atomically.
Definition branch.hpp:439
void add_cells_used(uint32_t n)
Increment cells_used counter (for standalone prefill/step outside BranchStore methods)
Definition branch.hpp:578
const BranchState * get(BranchHandle handle) const
Look up branch state by handle.
Definition branch.hpp:660
void free_sampler(SamplerChainHandle h)
Free a sampler chain.
Definition branch.hpp:703
void retainOnly(BranchHandle winner)
Keep only the winner — nuclear KV + CPU cleanup.
Definition branch.hpp:534
KvPressure kv_pressure() const
KV cache pressure snapshot — O(1), no tree walking.
Definition branch.hpp:570
SamplerChainHandle create_sampler(const P &params)
Create a sampler chain and register it.
Definition branch.hpp:672
llama_sampler * get_grammar_sampler(GrammarHandle h) const
Dereference a grammar handle (non-owning)
Definition branch.hpp:801
void add_sampling_surprisal(MetricsHandle h, float surprisal)
Add sampling-level surprisal to a metrics tracker.
Definition branch.hpp:860
void init_tenancy(llama_context *ctx)
Initialize KV tenancy after context creation.
Definition branch.hpp:499
void merge_logits(BranchHandle dst_handle, std::span< const BranchHandle > expert_handles, float alpha)
Merge experts' logits_snapshot into dst's logits_snapshot.
Definition branch.hpp:1102
GrammarHandle clone_grammar(GrammarHandle h)
Clone a grammar (for fork)
Definition branch.hpp:777
void decode_scatter(std::span< const DecodeScatterItem > items)
Decode variable token counts per branch with auto-chunking.
Definition branch.hpp:993
BranchStore(size_t initial_capacity=16)
Construct a branch store with initial slot capacity.
Definition branch.hpp:398
void free_metrics(MetricsHandle h)
Free a metrics tracker.
Definition branch.hpp:837
const std::vector< BranchHandle > & children(BranchHandle h) const
Get a branch's child handles.
Definition branch.hpp:605
BranchState * get(BranchHandle handle)
Look up branch state by handle.
Definition branch.hpp:640
BranchHandle parent(BranchHandle h) const
Get a branch's parent handle.
Definition branch.hpp:585
void add_model_surprisal(MetricsHandle h, float surprisal)
Add model-level surprisal to a metrics tracker.
Definition branch.hpp:846
llama_sampler * get_sampler_chain(SamplerChainHandle h) const
Dereference a sampler chain handle (non-owning)
Definition branch.hpp:712
~Branch()
Destructor — CASCADE prunes entire subtree.
Definition branch.hpp:2592
void accept(llama_token token)
Accept a token — advance grammar, penalty window, and metrics.
Definition branch.hpp:2685
Branch & operator=(const Branch &)=delete
int n_vocab() const
Vocabulary size.
Definition branch.hpp:2718
void setGrammar(const char *grammar_str)
Replace grammar constraint (nullptr/empty to remove)
Definition branch.hpp:2704
const std::vector< BranchHandle > & childHandles() const
Child branch handles (empty if leaf)
Definition branch.hpp:2729
void setSamplerParams(const P &params)
Replace sampler chain with new parameters (memoized)
Definition branch.hpp:2698
BranchHandle handle() const
Underlying opaque handle (for interop with free functions)
Definition branch.hpp:2722
const float * logits() const
Get the branch's captured logits snapshot.
Definition branch.hpp:2672
bool is_eog(llama_token token) const
Check if a token is end-of-generation for this branch's model.
Definition branch.hpp:2690
BranchHandle parentHandle() const
Parent branch handle, or INVALID_HANDLE if root.
Definition branch.hpp:2727
Branch(BranchStore *store, BranchHandle handle)
Definition branch.hpp:2588
llama_pos position() const
Current decode position (token count)
Definition branch.hpp:2712
llama_token sample()
Sample a token from captured logits.
Definition branch.hpp:2679
void step(llama_token token)
Decode one token and capture logits (generation step)
Definition branch.hpp:2666
void pruneSubtree()
CASCADE prune — removes entire subtree.
Definition branch.hpp:2646
float perplexity() const
Model-level perplexity (from raw logits, pre-filter)
Definition branch.hpp:2716
Branch(const Branch &)=delete
bool isActive() const
True if this branch holds a KV lease.
Definition branch.hpp:2736
llama_pos forkHead() const
Parent's position at fork time (0 for root branches)
Definition branch.hpp:2714
void force_snapshot_logits()
Force-copy shared logits buffer into this branch's snapshot.
Definition branch.hpp:2653
bool isLeaf() const
True if this branch has no children.
Definition branch.hpp:2734
void prefill(const llama_token *tokens, size_t n)
Decode multiple tokens and capture logits atomically (prompt injection)
Definition branch.hpp:2660
bool valid() const
True if this Branch holds a valid handle.
Definition branch.hpp:2720
void prune()
RESTRICT prune (throws if children exist)
Definition branch.hpp:2640
Branch fork()
Fork: allocates slot + lease, records topology edge.
Definition branch.hpp:2634
Branch & operator=(Branch &&other) noexcept
Definition branch.hpp:2603
static Branch create(llama_context *ctx, const llama_model *model, BranchStore &store, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Factory: allocates slot + lease from store.
Definition branch.hpp:2620
Branch(Branch &&other) noexcept
Definition branch.hpp:2598
#define LLOYAL_LOG_DEBUG(...)
liblloyal - Common definitions and logging
Definition common.hpp:48
Batch Decoding Operations.
Grammar-Constrained Sampling.
constexpr llama_seq_id NO_LEASE
Sentinel value indicating a branch has no KV residency.
Definition kv.hpp:207
KV Cache Physics.
Zero-copy logits access with clear lifetime semantics.
Distribution Metrics for Test-Time Alignment.
BranchHandle fork(BranchHandle source, BranchStore &s, ForkOpts opts={})
Fork a branch into a new independent sequence.
Definition branch.hpp:1383
void prefill(BranchHandle handle, const llama_token *tokens, size_t n_tokens, BranchStore &s)
Decode multiple tokens and capture logits atomically (prompt prefill)
Definition branch.hpp:1802
void set_logit_bias(BranchHandle handle, const llama_logit_bias *biases, size_t n_biases, BranchStore &s)
Definition branch.hpp:1486
float get_perplexity(BranchHandle handle, BranchStore &s)
Get model-level perplexity (from raw logits)
Definition branch.hpp:2482
void apply_grammar(BranchHandle handle, float *logits, int n_vocab, BranchStore &s)
Apply grammar constraints to an external logits buffer.
Definition branch.hpp:2156
int32_t GrammarHandle
Handle to a grammar sampler in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:154
uint32_t BranchHandle
Opaque handle to a branch slot.
Definition branch.hpp:95
void step(BranchHandle handle, llama_token token, BranchStore &s)
Decode a single token and capture logits (generation step)
Definition branch.hpp:1847
constexpr int DEFAULT_N_BATCH
Default batch size for decode operations.
Definition branch.hpp:99
void set_grammar_lazy(BranchHandle handle, const llama_model *model, const char *grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, BranchStore &s)
Set lazy grammar on a branch (unconstrained until trigger fires)
Definition branch.hpp:1682
void set_logits(BranchHandle handle, std::span< const float > data, BranchStore &s)
Overwrite a branch's logits_snapshot with caller-provided values.
Definition branch.hpp:1916
void set_steer(BranchHandle handle, std::function< void(llama_token_data_array &)> steer_fn, BranchStore &s)
Definition branch.hpp:1553
constexpr uint32_t GEN_SHIFT
Bit shift for generation field.
Definition branch.hpp:100
void accept_token(BranchHandle handle, llama_token token, BranchStore &s)
Accept a sampled token, advancing grammar and sampler state.
Definition branch.hpp:2088
void force_snapshot_logits(BranchHandle handle, BranchStore &s)
Force-copy the shared llama.cpp logits buffer into this branch's private snapshot.
Definition branch.hpp:1766
llama_pos get_fork_head(BranchHandle handle, BranchStore &s)
Get the branch's fork head (parent position at fork time)
Definition branch.hpp:2466
float get_sampling_perplexity(BranchHandle handle, BranchStore &s)
Get sampling-level perplexity (from filtered distribution)
Definition branch.hpp:2502
float get_legal_logsumexp(BranchHandle handle, BranchStore &s)
Compute log-sum-exp over grammar-legal logits.
Definition branch.hpp:2291
void clear_logit_bias(BranchHandle handle, BranchStore &s)
Clear all logit biases from a branch.
Definition branch.hpp:1512
uint16_t handle_generation(BranchHandle h)
Extract generation counter from a branch handle.
Definition branch.hpp:134
const float * get_logits(BranchHandle handle, BranchStore &s)
Get the branch's captured logits snapshot.
Definition branch.hpp:1886
void set_sampler_params(BranchHandle handle, const P &params, BranchStore &s)
Replace a branch's sampler chain with new parameters.
Definition branch.hpp:1607
void prune(BranchHandle handle, BranchStore &s)
Prune a leaf branch (RESTRICT — throws if children exist)
Definition branch.hpp:1718
BranchHandle create(llama_context *ctx, const llama_model *model, BranchStore &s, llama_pos start_pos, const P &params, int n_batch=DEFAULT_N_BATCH, const char *grammar_str=nullptr, boundaries::BoundaryTracker *boundary_tracker=nullptr)
Create a new branch with sampler chain, optional grammar, and metrics.
Definition branch.hpp:1290
CachedSamplingParams snapshot_params(const P &p)
Snapshot sampling params for memoization comparison.
Definition branch.hpp:237
constexpr uint32_t INDEX_MASK
Mask for slot index field.
Definition branch.hpp:101
int get_n_vocab(BranchHandle handle, BranchStore &s)
Get the branch's vocabulary size.
Definition branch.hpp:2558
llama_token sample(BranchHandle handle, BranchStore &s)
Sample a token from the branch's captured logits.
Definition branch.hpp:1996
void pruneSubtree(BranchHandle h, BranchStore &s)
Prune a branch and all descendants (CASCADE — iterative post-order)
Definition branch.hpp:1735
constexpr BranchHandle INVALID_HANDLE
Null handle sentinel.
Definition branch.hpp:97
float get_token_prior(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token, checking grammar legality first.
Definition branch.hpp:2434
constexpr llama_seq_id NO_LEASE
Branch has no KV residency.
Definition branch.hpp:98
int32_t MetricsHandle
Handle to a metrics tracker in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:157
llama_pos get_position(BranchHandle handle, BranchStore &s)
Get the branch's current decode position.
Definition branch.hpp:2454
void set_grammar(BranchHandle handle, const llama_model *model, const char *grammar_str, BranchStore &s)
Replace a branch's grammar constraint.
Definition branch.hpp:1643
void clear_steer(BranchHandle handle, BranchStore &s)
Clear the steer callback from a branch.
Definition branch.hpp:1576
float get_token_prior_assume_legal(BranchHandle handle, llama_token token, float logsumexp, BranchStore &s)
Compute prior probability for a token known to be grammar-legal.
Definition branch.hpp:2403
BranchHandle make_handle(uint16_t index, uint16_t generation)
Construct a branch handle from index and generation.
Definition branch.hpp:144
uint16_t handle_index(BranchHandle h)
Extract slot index from a branch handle.
Definition branch.hpp:125
int32_t SamplerChainHandle
Handle to a sampler chain in BranchStore's registry (0 = invalid/none)
Definition branch.hpp:151
std::vector< std::pair< llama_token, float > > get_legal_priors(BranchHandle handle, BranchStore &s)
Get grammar-legal tokens with renormalized probabilities.
Definition branch.hpp:2212
bool is_token_legal(BranchHandle handle, llama_token token, BranchStore &s)
Check if a token is legal under grammar constraints.
Definition branch.hpp:2352
float get_last_sampling_prior(BranchHandle handle, BranchStore &s)
Get the last sampled token's prior from the filtered distribution.
Definition branch.hpp:2521
void get_logits_at(BranchHandle handle, std::span< const int32_t > indices, std::span< float > out, BranchStore &s)
Read selected logit values from a branch's snapshot.
Definition branch.hpp:1955
std::vector< PackedChunk > bin_pack(const std::span< const llama_token > *items, int32_t n, int32_t n_batch)
Greedy first-fit bin-packing of token spans into n_batch-sized chunks.
Definition decode.hpp:481
int one(llama_context *ctx, llama_token tok, llama_pos pos, llama_seq_id seq_id=0, bool want_logits=true)
Decode a single token into the KV cache.
Definition decode.hpp:239
int many(llama_context *ctx, const llama_token *tokens, int32_t n_tokens, int32_t n_past, int32_t n_batch, llama_seq_id seq_id=0)
Decode multiple tokens into the KV cache with auto-chunking.
Definition decode.hpp:125
int scatter(llama_context *ctx, const ScatterItem *items, int32_t n, Scratch &scratch)
Decode multiple tokens per sequence in a single llama_decode() call.
Definition decode.hpp:396
int each(llama_context *ctx, const EachItem *items, int32_t n, Scratch &scratch)
Decode one token per sequence in a single llama_decode() call.
Definition decode.hpp:340
constexpr T as_value(const X &x, T def)
Extract value from either T or std::optional<T> with fallback.
Definition sampler.hpp:51
void free_sampler(llama_sampler *smpl)
Free a grammar sampler.
Definition grammar.hpp:235
llama_sampler * clone_sampler(llama_sampler *smpl)
Clone a grammar sampler (for fork/branching).
Definition grammar.hpp:213
llama_sampler * init_sampler(const llama_model *model, const std::string &grammar_str, const std::string &root_rule="root")
Initialize a grammar sampler from GBNF grammar string.
Definition grammar.hpp:107
llama_sampler * init_lazy_sampler(const llama_model *model, const std::string &grammar_str, const std::vector< std::string > &trigger_patterns, const std::vector< llama_token > &trigger_tokens, const std::string &root_rule="root")
Initialize a lazy grammar sampler from GBNF grammar string.
Definition grammar.hpp:158
void accept(llama_sampler *smpl, llama_token token)
Accept a token into grammar state.
Definition grammar.hpp:263
void apply(llama_sampler *smpl, llama_token_data_array *cur_p)
Apply grammar constraint to candidates.
Definition grammar.hpp:249
void evict_all(State &s)
Evict every leased seq_id.
Definition kv.hpp:351
llama_seq_id acquire(State &s)
Acquire a seq_id from the vacant pool.
Definition kv.hpp:256
size_t available(const State &s)
Number of vacant seq_ids available for acquisition.
Definition kv.hpp:365
void evict(State &s, llama_seq_id seq)
Evict a seq_id — strip all KV tags then release.
Definition kv.hpp:297
void retain(State &s, llama_seq_id keep)
Nuclear retain — keep one seq, rebuild vacancy from scratch.
Definition kv.hpp:320
State init(llama_context *ctx, llama_seq_id n_seq_max)
Initialize tenancy with all seq_ids vacant.
Definition kv.hpp:234
void release(State &s, llama_seq_id seq)
Release a seq_id back to vacant — bookkeeping only, no KV calls.
Definition kv.hpp:275
void seq_cp(llama_context *ctx, llama_seq_id src, llama_seq_id dst, llama_pos p0=0, llama_pos p1=-1)
Copy KV cache from one sequence to another.
Definition kv.hpp:138
llama_pos pos_max(llama_context *ctx, llama_seq_id seq)
Get maximum position in KV cache sequence.
Definition kv.hpp:111
float * get(llama_context *ctx, int32_t index=-1)
Definition logits.hpp:79
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
Definition metrics.hpp:227
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Definition metrics.hpp:132
void apply(llama_sampler *chain, llama_token_data_array *cur_p)
Apply a sampler chain to a candidate array.
Definition sampler.hpp:582
void accept(llama_sampler *chain, llama_token token)
Accept a token into the sampler chain.
Definition sampler.hpp:596
llama_sampler * clone_chain(llama_sampler *chain)
Clone a sampler chain.
Definition sampler.hpp:527
llama_sampler * create_chain(const P &params)
Create a persistent sampler chain from parameters.
Definition sampler.hpp:466
void free_chain(llama_sampler *chain)
Free a sampler chain.
Definition sampler.hpp:568
bool is_eog(const llama_vocab *vocab, llama_token token)
Check if a token is an end-of-generation marker.
Token Sampling Operations.
Consolidated mutable state for a single branch.
Definition branch.hpp:275
bool in_use
True when slot is allocated to an active branch.
Definition branch.hpp:315
const llama_model * model
Llama model (not owned, must outlive branch)
Definition branch.hpp:277
bool has_logits
True only after force_snapshot_logits(), prefill(), or step()
Definition branch.hpp:299
std::function< void(llama_token_data_array &)> steer_fn
Dynamic logit callback, NOT cloned on fork.
Definition branch.hpp:291
int n_batch
Batch size for decode operations.
Definition branch.hpp:311
llama_seq_id seq_id
KV cache sequence identifier (NO_LEASE when inactive)
Definition branch.hpp:279
llama_pos position
Current decode position in the sequence.
Definition branch.hpp:280
std::vector< llama_logit_bias > logit_bias
Static token biases, cloned on fork.
Definition branch.hpp:290
std::vector< llama_token_data > last_candidates
Filtered candidates from last sample()
Definition branch.hpp:296
std::vector< float > logits_snapshot
Captured logit distribution (n_vocab floats)
Definition branch.hpp:298
std::vector< llama_token_data > candidates_buffer
Reusable scratch buffer for sampling (avoids O(n_vocab) allocs per sample call).
Definition branch.hpp:309
boundaries::BoundaryTracker * boundary_tracker
Token boundary detector (owned, optional)
Definition branch.hpp:288
SamplerChainHandle sampler_chain
Handle into BranchStore's sampler registry.
Definition branch.hpp:283
llama_context * ctx
Llama context (not owned, must outlive branch)
Definition branch.hpp:276
BranchHandle parent
Parent branch (INVALID_HANDLE if root)
Definition branch.hpp:318
int n_vocab
Vocabulary size (cached for buffer pre-allocation)
Definition branch.hpp:312
GrammarHandle grammar
Handle into BranchStore's grammar registry.
Definition branch.hpp:284
std::vector< BranchHandle > children
Child branches forked from this one.
Definition branch.hpp:319
llama_token last_token
Last token returned by sample()
Definition branch.hpp:295
MetricsHandle metrics
Handle into BranchStore's metrics registry.
Definition branch.hpp:293
llama_pos fork_head
Parent's position at fork time (0 for root branches)
Definition branch.hpp:281
CachedSamplingParams cached_params
Params used to create current chain (for memoization)
Definition branch.hpp:286
uint16_t generation
Slot generation counter (for ABA prevention)
Definition branch.hpp:314
Result of allocate(): a slot handle + its leased seq_id.
Definition branch.hpp:429
Concrete sampling params snapshot for memoization.
Definition branch.hpp:215
bool operator==(const CachedSamplingParams &) const =default
Item for decode_each: one token per branch.
Definition branch.hpp:330
Item for decode_scatter: variable tokens per branch.
Definition branch.hpp:348
std::span< const llama_token > tokens
Definition branch.hpp:350
Options controlling what state fork() clones from source to child.
Definition branch.hpp:1352
bool clone_logits
If true (default), copy src->logits_snapshot to child (~n_vocab*4 bytes, ~600KB for 150k-vocab models)...
Definition branch.hpp:1358
RAII entry for a grammar sampler in the registry.
Definition branch.hpp:190
GrammarEntry(GrammarEntry &&o) noexcept
Definition branch.hpp:196
GrammarEntry & operator=(const GrammarEntry &)=delete
GrammarEntry & operator=(GrammarEntry &&o) noexcept
Definition branch.hpp:197
GrammarEntry(const GrammarEntry &)=delete
llama_sampler * sampler
Definition branch.hpp:191
Snapshot of KV cache pressure from BranchStore.
Definition branch.hpp:114
uint32_t remaining
n_ctx - cells_used (clamped to 0)
Definition branch.hpp:117
uint32_t n_ctx
Total KV capacity.
Definition branch.hpp:115
uint32_t cells_used
Cells allocated since last reset.
Definition branch.hpp:116
RAII entry for a sampler chain in the registry.
Definition branch.hpp:165
SamplerChainEntry(const SamplerChainEntry &)=delete
SamplerChainEntry & operator=(SamplerChainEntry &&o) noexcept
Definition branch.hpp:174
SamplerChainEntry(SamplerChainEntry &&o) noexcept
Definition branch.hpp:172
bool has_dist
True if chain ends with dist (temp > 0), false if greedy.
Definition branch.hpp:167
SamplerChainEntry & operator=(const SamplerChainEntry &)=delete
llama_context * ctx
Context for KV operations (nullptr after drain)
Definition kv.hpp:218
Unified model + sampling perplexity tracker.
Definition metrics.hpp:100