liblloyal 1.0.0
Branched Inference for llama.cpp
Loading...
Searching...
No Matches
metrics.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6
34#include <algorithm>
35#include <cmath>
36#include <cstdint>
37#include <limits>
38
39namespace lloyal::metrics {
40
41// ============================================================================
42// Types
43// ============================================================================
44
/// Logarithm base in which surprisal and entropy are reported.
enum class Base {
  Nats,  ///< natural logarithm (the unit all metrics are computed in)
  Bits,  ///< base-2 logarithm (nats divided by ln 2)
};
46
47// ============================================================================
48// Internal helpers (ported from metrics.ts)
49// ============================================================================
50
namespace detail {

/// ln(2). Divide a quantity in nats by this to express it in bits.
inline constexpr float LN2 = 0.693147180559945309417232121458176568f;

/// @brief Largest finite value in a[0..n).
///
/// NaN and +/-infinity entries are skipped. Used as the shift for
/// log_sum_exp(): with shift = max, every exponentiated term is <= 1,
/// so nothing overflows.
///
/// @param a Pointer to n values.
/// @param n Number of values.
/// @return The maximum finite element, or -infinity when n <= 0 or no
///         element is finite.
inline float max_finite(const float* a, int n) {
  float m = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < n; ++i) {
    const float v = a[i];
    if (std::isfinite(v) && v > m) m = v;
  }
  return m;
}

/// @brief Numerically stable log(sum_i exp(a_i)) over the finite entries of a.
///
/// Computes shift + log(sum_i exp(a_i - shift)). Non-finite entries are
/// ignored. Callers pass max_finite(a, n) as the shift to avoid overflow.
///
/// @param a     Pointer to n values.
/// @param n     Number of values.
/// @param shift Value subtracted before exponentiation and added back after.
/// @return The log-sum-exp, or -infinity when no finite entry contributes.
inline float log_sum_exp(const float* a, int n, float shift) {
  // Accumulate in double. With one dominant term (exp(0) == 1) and a long
  // vocabulary tail of tiny terms (e.g. ~1e5 entries at exp(-20) ~ 2e-9),
  // a float accumulator rounds every tail term away, because each is below
  // half an ulp of 1.0f. That silently drops the tail's probability mass.
  const double shift_d = static_cast<double>(shift);
  double s = 0.0;
  for (int i = 0; i < n; ++i) {
    const float v = a[i];
    if (std::isfinite(v)) s += std::exp(static_cast<double>(v) - shift_d);
  }
  if (s == 0.0) return -std::numeric_limits<float>::infinity();
  return static_cast<float>(shift_d + std::log(s));
}

}  // namespace detail
88
89// ============================================================================
90// Perplexity tracking types (used by BranchStore registry)
91// ============================================================================
92
95 float nll_sum_nats = 0.0f;
96 int count = 0;
97};
98
104
105// ============================================================================
106// Model-level metrics (raw logits, before filters)
107// ============================================================================
108
132inline float model_surprisal(
133 const float* logits,
134 int n_vocab,
135 int picked_id,
136 Base base = Base::Nats
137) {
138 if (!logits || n_vocab == 0) {
139 return std::numeric_limits<float>::infinity();
140 }
141 if (picked_id < 0 || picked_id >= n_vocab) {
142 return std::numeric_limits<float>::infinity();
143 }
144
145 const float picked = logits[picked_id];
146 if (!std::isfinite(picked)) return std::numeric_limits<float>::infinity();
147
148 const float m = detail::max_finite(logits, n_vocab);
149 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
150
151 const float log_z = detail::log_sum_exp(logits, n_vocab, m);
152 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
153
154 const float surprisal_nats = std::max(0.0f, -(picked - log_z));
155 return base == Base::Bits ? surprisal_nats / detail::LN2 : surprisal_nats;
156}
157
181inline float model_entropy(
182 const float* logits,
183 int n_vocab,
184 Base base = Base::Nats
185) {
186 if (!logits || n_vocab == 0) {
187 return std::numeric_limits<float>::infinity();
188 }
189
190 const float m = detail::max_finite(logits, n_vocab);
191 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
192
193 const float log_z = detail::log_sum_exp(logits, n_vocab, m);
194 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
195
196 float ez = 0.0f;
197 for (int i = 0; i < n_vocab; ++i) {
198 const float z = logits[i];
199 if (!std::isfinite(z)) continue;
200 const float p = std::exp(z - log_z);
201 ez += p * z;
202 }
203
204 const float h_nats = std::max(0.0f, log_z - ez);
205 return base == Base::Bits ? h_nats / detail::LN2 : h_nats;
206}
207
208// ============================================================================
209// Sampling-level metrics (post-filter logits, after top-k/p/temp)
210// ============================================================================
211
228 const float* candidate_logits,
229 const int32_t* candidate_ids,
230 int n_candidates,
231 int picked_id,
232 Base base = Base::Nats
233) {
234 if (!candidate_logits || !candidate_ids || n_candidates == 0) {
235 return std::numeric_limits<float>::infinity();
236 }
237
238 // Find picked_id in candidates
239 int local = -1;
240 for (int i = 0; i < n_candidates; ++i) {
241 if (candidate_ids[i] == picked_id) {
242 local = i;
243 break;
244 }
245 }
246 if (local == -1) return std::numeric_limits<float>::infinity();
247 if (n_candidates == 1) return 0.0f;
248
249 const float picked = candidate_logits[local];
250 if (!std::isfinite(picked)) return std::numeric_limits<float>::infinity();
251
252 const float m = detail::max_finite(candidate_logits, n_candidates);
253 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
254
255 const float log_z = detail::log_sum_exp(candidate_logits, n_candidates, m);
256 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
257
258 const float surprisal_nats = std::max(0.0f, -(picked - log_z));
259 return base == Base::Bits ? surprisal_nats / detail::LN2 : surprisal_nats;
260}
261
273inline float sampling_entropy(
274 const float* candidate_logits,
275 int n_candidates,
276 Base base = Base::Nats
277) {
278 if (!candidate_logits || n_candidates == 0) {
279 return std::numeric_limits<float>::infinity();
280 }
281 if (n_candidates == 1) return 0.0f;
282
283 const float m = detail::max_finite(candidate_logits, n_candidates);
284 if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
285
286 const float log_z = detail::log_sum_exp(candidate_logits, n_candidates, m);
287 if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
288
289 float ez = 0.0f;
290 for (int i = 0; i < n_candidates; ++i) {
291 const float z = candidate_logits[i];
292 if (!std::isfinite(z)) continue;
293 const float p = std::exp(z - log_z);
294 ez += p * z;
295 }
296
297 const float h_nats = std::max(0.0f, log_z - ez);
298 return base == Base::Bits ? h_nats / detail::LN2 : h_nats;
299}
300
301} // namespace lloyal::metrics
constexpr float LN2
Definition metrics.hpp:53
float max_finite(const float *a, int n)
Find the maximum finite value in an array. Used as the log-sum-exp shift to prevent overflow.
Definition metrics.hpp:59
float log_sum_exp(const float *a, int n, float shift)
Numerically stable log-sum-exp. Computes log(Σ exp(aᵢ)) using the shift trick to avoid overflow.
Definition metrics.hpp:77
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
Definition metrics.hpp:227
float model_entropy(const float *logits, int n_vocab, Base base=Base::Nats)
Definition metrics.hpp:181
float sampling_entropy(const float *candidate_logits, int n_candidates, Base base=Base::Nats)
Compute sampling-level entropy of candidate distribution.
Definition metrics.hpp:273
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Definition metrics.hpp:132
Unified model + sampling perplexity tracker.
Definition metrics.hpp:100
PerplexityState model
Model-level (raw logits before filters)
Definition metrics.hpp:101
PerplexityState sampling
Sampling-level (post top-k/p/temp)
Definition metrics.hpp:102
Rolling NLL accumulator for perplexity computation.
Definition metrics.hpp:94