liblloyal 1.0.0
Composable primitives for llama.cpp inference
Loading...
Searching...
No Matches
metrics.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <vector>
35
36namespace lloyal::metrics {
37
38// ============================================================================
39// Types
40// ============================================================================
41
/// Logarithm base for the information quantities returned by this header.
enum class Base {
  Nats,  ///< natural logarithm (the default `base` argument of the metric functions)
  Bits   ///< base-2 logarithm (value in nats divided by ln 2)
};

/// Key into the perplexity-tracker registry.
using PerplexityHandle = std::int32_t;
45
46// ============================================================================
47// Internal helpers (ported from metrics.ts)
48// ============================================================================
49
50namespace detail {
51
52constexpr float LN2 = 0.693147180559945309417232121458176568f;
53
/// Returns the largest finite element of a[0..n).
///
/// NaN and +/-inf entries are ignored. When no finite entry exists (including
/// n <= 0) the result is -inf; callers check it with std::isfinite before
/// using it as the log-sum-exp shift.
inline float max_finite(const float* a, int n) {
  float best = -std::numeric_limits<float>::infinity();
  for (int idx = 0; idx < n; ++idx) {
    const float candidate = a[idx];
    if (!std::isfinite(candidate)) {
      continue;  // skip NaN and masked (+/-inf) entries
    }
    if (candidate > best) {
      best = candidate;
    }
  }
  return best;
}
66
/// Numerically stable log-sum-exp over the finite entries of a[0..n).
///
/// Computes log(sum_i exp(a[i])) as shift + log(sum_i exp(a[i] - shift)).
/// Non-finite entries (NaN, +/-inf) are skipped.
///
/// @param a      Input values (only dereferenced for indices < n).
/// @param n      Number of entries.
/// @param shift  Subtracted before exponentiation to avoid overflow; the
///               callers in this header pass max_finite(a, n), so every
///               exp() term is <= 1.
/// @return log-sum-exp of the finite entries, or -inf when nothing
///         contributes (n <= 0, all entries non-finite, or every term
///         underflows to zero).
inline float log_sum_exp(const float* a, int n, float shift) {
  // Accumulate in double. With vocabularies of 10^5+ entries the many tail
  // terms exp(a[i] - shift) are each below float resolution relative to the
  // leading term (~1) and a float running sum would drop them entirely,
  // biasing log Z (and therefore surprisal/entropy) low.
  double s = 0.0;
  for (int i = 0; i < n; ++i) {
    const float v = a[i];
    if (std::isfinite(v)) s += std::exp(static_cast<double>(v) - shift);
  }
  if (s == 0.0) return -std::numeric_limits<float>::infinity();
  return static_cast<float>(shift + std::log(s));
}
85
86// Perplexity state for handle-based tracking
88 float nll_sum_nats = 0.0f;
89 int count = 0;
90};
91
92inline std::unordered_map<PerplexityHandle, PerplexityState>& get_registry() {
93 static std::unordered_map<PerplexityHandle, PerplexityState> registry;
94 return registry;
95}
96
98 static PerplexityHandle next = 1;
99 return next;
100}
101
102} // namespace detail
103
104// ============================================================================
105// Model-level metrics (raw logits, before filters)
106// ============================================================================
107
131inline float model_surprisal(
132 const float* logits,
133 int n_vocab,
134 int picked_id,
135 Base base = Base::Nats
136) {
137 if (!logits || n_vocab == 0) {
138 return std::numeric_limits<float>::infinity();
139 }
140 if (picked_id < 0 || picked_id >= n_vocab) {
141 return std::numeric_limits<float>::infinity();
142 }
143
144 const float picked = logits[picked_id];
145 if (!std::isfinite(picked)) return std::numeric_limits<float>::infinity();
146
147 const float m = detail::max_finite(logits, n_vocab);
148 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
149
150 const float log_z = detail::log_sum_exp(logits, n_vocab, m);
151 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
152
153 const float surprisal_nats = std::max(0.0f, -(picked - log_z));
154 return base == Base::Bits ? surprisal_nats / detail::LN2 : surprisal_nats;
155}
156
180inline float model_entropy(
181 const float* logits,
182 int n_vocab,
183 Base base = Base::Nats
184) {
185 if (!logits || n_vocab == 0) {
186 return std::numeric_limits<float>::infinity();
187 }
188
189 const float m = detail::max_finite(logits, n_vocab);
190 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
191
192 const float log_z = detail::log_sum_exp(logits, n_vocab, m);
193 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
194
195 float ez = 0.0f;
196 for (int i = 0; i < n_vocab; ++i) {
197 const float z = logits[i];
198 if (!std::isfinite(z)) continue;
199 const float p = std::exp(z - log_z);
200 ez += p * z;
201 }
202
203 const float h_nats = std::max(0.0f, log_z - ez);
204 return base == Base::Bits ? h_nats / detail::LN2 : h_nats;
205}
206
207// ============================================================================
208// Sampling-level metrics (post-filter logits, after top-k/p/temp)
209// ============================================================================
210
227 const float* candidate_logits,
228 const int32_t* candidate_ids,
229 int n_candidates,
230 int picked_id,
231 Base base = Base::Nats
232) {
233 if (!candidate_logits || !candidate_ids || n_candidates == 0) {
234 return std::numeric_limits<float>::infinity();
235 }
236
237 // Find picked_id in candidates
238 int local = -1;
239 for (int i = 0; i < n_candidates; ++i) {
240 if (candidate_ids[i] == picked_id) {
241 local = i;
242 break;
243 }
244 }
245 if (local == -1) return std::numeric_limits<float>::infinity();
246 if (n_candidates == 1) return 0.0f;
247
248 const float picked = candidate_logits[local];
249 if (!std::isfinite(picked)) return std::numeric_limits<float>::infinity();
250
251 const float m = detail::max_finite(candidate_logits, n_candidates);
252 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
253
254 const float log_z = detail::log_sum_exp(candidate_logits, n_candidates, m);
255 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
256
257 const float surprisal_nats = std::max(0.0f, -(picked - log_z));
258 return base == Base::Bits ? surprisal_nats / detail::LN2 : surprisal_nats;
259}
260
272inline float sampling_entropy(
273 const float* candidate_logits,
274 int n_candidates,
275 Base base = Base::Nats
276) {
277 if (!candidate_logits || n_candidates == 0) {
278 return std::numeric_limits<float>::infinity();
279 }
280 if (n_candidates == 1) return 0.0f;
281
282 const float m = detail::max_finite(candidate_logits, n_candidates);
283 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
284
285 const float log_z = detail::log_sum_exp(candidate_logits, n_candidates, m);
286 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
287
288 float ez = 0.0f;
289 for (int i = 0; i < n_candidates; ++i) {
290 const float z = candidate_logits[i];
291 if (!std::isfinite(z)) continue;
292 const float p = std::exp(z - log_z);
293 ez += p * z;
294 }
295
296 const float h_nats = std::max(0.0f, log_z - ez);
297 return base == Base::Bits ? h_nats / detail::LN2 : h_nats;
298}
299
300// ============================================================================
301// Handle-based RollingPerplexity (supports clone for fork)
302// ============================================================================
303
332
339inline void add_surprisal(PerplexityHandle handle, float surprisal) {
340 auto& registry = detail::get_registry();
341 auto it = registry.find(handle);
342 if (it == registry.end()) return;
343 if (!std::isfinite(surprisal)) return;
344 it->second.nll_sum_nats += std::max(0.0f, surprisal);
345 it->second.count++;
346}
347
354inline float get_ppl(PerplexityHandle handle) {
355 auto& registry = detail::get_registry();
356 auto it = registry.find(handle);
357 if (it == registry.end() || it->second.count == 0) {
358 return std::numeric_limits<float>::infinity();
359 }
360 return std::exp(it->second.nll_sum_nats / static_cast<float>(it->second.count));
361}
362
369inline int get_count(PerplexityHandle handle) {
370 auto& registry = detail::get_registry();
371 auto it = registry.find(handle);
372 if (it == registry.end()) return 0;
373 return it->second.count;
374}
375
382 auto& registry = detail::get_registry();
383 auto it = registry.find(handle);
384 if (it != registry.end()) {
385 it->second = detail::PerplexityState{};
386 }
387}
388
405 auto& registry = detail::get_registry();
406 auto it = registry.find(handle);
407 if (it == registry.end()) return 0;
408
410 registry[new_handle] = it->second; // Copy state
411 return new_handle;
412}
413
420 detail::get_registry().erase(handle);
421}
422
423// ============================================================================
424// Branch-Level Metrics (unified model + sampling tracking)
425// ============================================================================
426
/// Key into the branch-metrics registry (paired model + sampling trackers).
using BranchMetricsHandle = std::int32_t;
428
429namespace detail {
430
432 PerplexityState model; // Model-level (raw logits before filters)
433 PerplexityState sampling; // Sampling-level (post top-k/p/temp)
434};
435
436inline std::unordered_map<BranchMetricsHandle, BranchMetricsState>&
438 static std::unordered_map<BranchMetricsHandle, BranchMetricsState> registry;
439 return registry;
440}
441
443 static BranchMetricsHandle next = 1;
444 return next;
445}
446
447} // namespace detail
448
462
471
482 auto& registry = detail::get_branch_metrics_registry();
483 auto it = registry.find(handle);
484 if (it == registry.end()) return 0;
485
487 registry[new_handle] = it->second; // Copy both model and sampling state
488 return new_handle;
489}
490
497inline void add_model_surprisal(BranchMetricsHandle handle, float surprisal) {
498 auto& registry = detail::get_branch_metrics_registry();
499 auto it = registry.find(handle);
500 if (it == registry.end()) return;
501 if (!std::isfinite(surprisal)) return;
502 it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
503 it->second.model.count++;
504}
505
512inline float get_model_ppl(BranchMetricsHandle handle) {
513 auto& registry = detail::get_branch_metrics_registry();
514 auto it = registry.find(handle);
515 if (it == registry.end() || it->second.model.count == 0) {
516 return std::numeric_limits<float>::infinity();
517 }
518 return std::exp(it->second.model.nll_sum_nats /
519 static_cast<float>(it->second.model.count));
520}
521
528inline void add_sampling_surprisal(BranchMetricsHandle handle, float surprisal) {
529 auto& registry = detail::get_branch_metrics_registry();
530 auto it = registry.find(handle);
531 if (it == registry.end()) return;
532 if (!std::isfinite(surprisal)) return;
533 it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
534 it->second.sampling.count++;
535}
536
544 auto& registry = detail::get_branch_metrics_registry();
545 auto it = registry.find(handle);
546 if (it == registry.end() || it->second.sampling.count == 0) {
547 return std::numeric_limits<float>::infinity();
548 }
549 return std::exp(it->second.sampling.nll_sum_nats /
550 static_cast<float>(it->second.sampling.count));
551}
552
557 auto& registry = detail::get_branch_metrics_registry();
558 auto it = registry.find(handle);
559 if (it == registry.end()) return 0;
560 return it->second.model.count;
561}
562
567 auto& registry = detail::get_branch_metrics_registry();
568 auto it = registry.find(handle);
569 if (it == registry.end()) return 0;
570 return it->second.sampling.count;
571}
572
573} // namespace lloyal::metrics
std::unordered_map< PerplexityHandle, PerplexityState > & get_registry()
Definition metrics.hpp:92
constexpr float LN2
Definition metrics.hpp:52
std::unordered_map< BranchMetricsHandle, BranchMetricsState > & get_branch_metrics_registry()
Definition metrics.hpp:437
float max_finite(const float *a, int n)
Find the maximum finite value in an array. Used as the log-sum-exp shift to prevent overflow.
Definition metrics.hpp:58
float log_sum_exp(const float *a, int n, float shift)
Numerically stable log-sum-exp. Computes log(Σ exp(aᵢ)) using the shift trick to avoid overflow.
Definition metrics.hpp:76
PerplexityHandle & get_next_handle()
Definition metrics.hpp:97
BranchMetricsHandle & get_next_branch_metrics_handle()
Definition metrics.hpp:442
float get_model_ppl(BranchMetricsHandle handle)
Get model-level perplexity (from raw logits)
Definition metrics.hpp:512
void add_model_surprisal(BranchMetricsHandle handle, float surprisal)
Add model-level surprisal (from raw logits before filters)
Definition metrics.hpp:497
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
Definition metrics.hpp:226
int get_sampling_count(BranchMetricsHandle handle)
Get number of tokens in sampling-level tracker.
Definition metrics.hpp:566
void add_surprisal(PerplexityHandle handle, float surprisal)
Add token surprisal to running average.
Definition metrics.hpp:339
void reset_perplexity(PerplexityHandle handle)
Reset tracker to initial state (start new sequence)
Definition metrics.hpp:381
float get_sampling_ppl(BranchMetricsHandle handle)
Get sampling-level perplexity (from filtered distribution)
Definition metrics.hpp:543
int32_t PerplexityHandle
Definition metrics.hpp:44
PerplexityHandle create_perplexity()
Definition metrics.hpp:327
BranchMetricsHandle clone_branch_metrics(BranchMetricsHandle handle)
Clone branch metrics tracker (for fork/branching)
Definition metrics.hpp:481
BranchMetricsHandle create_branch_metrics()
Create unified branch metrics tracker.
Definition metrics.hpp:457
float model_entropy(const float *logits, int n_vocab, Base base=Base::Nats)
Definition metrics.hpp:180
float sampling_entropy(const float *candidate_logits, int n_candidates, Base base=Base::Nats)
Compute sampling-level entropy of candidate distribution.
Definition metrics.hpp:272
int get_count(PerplexityHandle handle)
Get number of tokens added to tracker.
Definition metrics.hpp:369
float get_ppl(PerplexityHandle handle)
Get current perplexity.
Definition metrics.hpp:354
void free_branch_metrics(BranchMetricsHandle handle)
Free branch metrics tracker.
Definition metrics.hpp:468
PerplexityHandle clone_perplexity(PerplexityHandle handle)
Definition metrics.hpp:404
void free_perplexity(PerplexityHandle handle)
Free perplexity tracker.
Definition metrics.hpp:419
int get_model_count(BranchMetricsHandle handle)
Get number of tokens in model-level tracker.
Definition metrics.hpp:556
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Definition metrics.hpp:131
void add_sampling_surprisal(BranchMetricsHandle handle, float surprisal)
Add sampling-level surprisal (from filtered distribution)
Definition metrics.hpp:528
int32_t BranchMetricsHandle
Definition metrics.hpp:427