33#include <unordered_map>
52constexpr float LN2 = 0.693147180559945309417232121458176568f;
59 float m = -std::numeric_limits<float>::infinity();
60 for (
int i = 0; i < n; ++i) {
62 if (std::isfinite(v) && v > m) m = v;
78 for (
int i = 0; i < n; ++i) {
80 if (std::isfinite(v)) s += std::exp(v - shift);
82 if (s == 0.0f)
return -std::numeric_limits<float>::infinity();
83 return shift + std::log(s);
92inline std::unordered_map<PerplexityHandle, PerplexityState>&
get_registry() {
93 static std::unordered_map<PerplexityHandle, PerplexityState> registry;
137 if (!logits || n_vocab == 0) {
138 return std::numeric_limits<float>::infinity();
140 if (picked_id < 0 || picked_id >= n_vocab) {
141 return std::numeric_limits<float>::infinity();
144 const float picked = logits[picked_id];
145 if (!std::isfinite(picked))
return std::numeric_limits<float>::infinity();
148 if (!std::isfinite(m))
return std::numeric_limits<float>::infinity();
151 if (!std::isfinite(log_z))
return std::numeric_limits<float>::infinity();
153 const float surprisal_nats = std::max(0.0f, -(picked - log_z));
185 if (!logits || n_vocab == 0) {
186 return std::numeric_limits<float>::infinity();
190 if (!std::isfinite(m))
return std::numeric_limits<float>::infinity();
193 if (!std::isfinite(log_z))
return std::numeric_limits<float>::infinity();
196 for (
int i = 0; i < n_vocab; ++i) {
197 const float z = logits[i];
198 if (!std::isfinite(z))
continue;
199 const float p = std::exp(z - log_z);
203 const float h_nats = std::max(0.0f, log_z - ez);
227 const float* candidate_logits,
228 const int32_t* candidate_ids,
233 if (!candidate_logits || !candidate_ids || n_candidates == 0) {
234 return std::numeric_limits<float>::infinity();
239 for (
int i = 0; i < n_candidates; ++i) {
240 if (candidate_ids[i] == picked_id) {
245 if (local == -1)
return std::numeric_limits<float>::infinity();
246 if (n_candidates == 1)
return 0.0f;
248 const float picked = candidate_logits[local];
249 if (!std::isfinite(picked))
return std::numeric_limits<float>::infinity();
252 if (!std::isfinite(m))
return std::numeric_limits<float>::infinity();
255 if (!std::isfinite(log_z))
return std::numeric_limits<float>::infinity();
257 const float surprisal_nats = std::max(0.0f, -(picked - log_z));
273 const float* candidate_logits,
277 if (!candidate_logits || n_candidates == 0) {
278 return std::numeric_limits<float>::infinity();
280 if (n_candidates == 1)
return 0.0f;
283 if (!std::isfinite(m))
return std::numeric_limits<float>::infinity();
286 if (!std::isfinite(log_z))
return std::numeric_limits<float>::infinity();
289 for (
int i = 0; i < n_candidates; ++i) {
290 const float z = candidate_logits[i];
291 if (!std::isfinite(z))
continue;
292 const float p = std::exp(z - log_z);
296 const float h_nats = std::max(0.0f, log_z - ez);
341 auto it = registry.find(handle);
342 if (it == registry.end())
return;
343 if (!std::isfinite(surprisal))
return;
344 it->second.nll_sum_nats += std::max(0.0f, surprisal);
356 auto it = registry.find(handle);
357 if (it == registry.end() || it->second.count == 0) {
358 return std::numeric_limits<float>::infinity();
360 return std::exp(it->second.nll_sum_nats /
static_cast<float>(it->second.count));
371 auto it = registry.find(handle);
372 if (it == registry.end())
return 0;
373 return it->second.count;
383 auto it = registry.find(handle);
384 if (it != registry.end()) {
406 auto it = registry.find(handle);
407 if (it == registry.end())
return 0;
410 registry[new_handle] = it->second;
436inline std::unordered_map<BranchMetricsHandle, BranchMetricsState>&
438 static std::unordered_map<BranchMetricsHandle, BranchMetricsState> registry;
483 auto it = registry.find(handle);
484 if (it == registry.end())
return 0;
487 registry[new_handle] = it->second;
499 auto it = registry.find(handle);
500 if (it == registry.end())
return;
501 if (!std::isfinite(surprisal))
return;
502 it->second.model.nll_sum_nats += std::max(0.0f, surprisal);
503 it->second.model.count++;
514 auto it = registry.find(handle);
515 if (it == registry.end() || it->second.model.count == 0) {
516 return std::numeric_limits<float>::infinity();
518 return std::exp(it->second.model.nll_sum_nats /
519 static_cast<float>(it->second.model.count));
530 auto it = registry.find(handle);
531 if (it == registry.end())
return;
532 if (!std::isfinite(surprisal))
return;
533 it->second.sampling.nll_sum_nats += std::max(0.0f, surprisal);
534 it->second.sampling.count++;
545 auto it = registry.find(handle);
546 if (it == registry.end() || it->second.sampling.count == 0) {
547 return std::numeric_limits<float>::infinity();
549 return std::exp(it->second.sampling.nll_sum_nats /
550 static_cast<float>(it->second.sampling.count));
558 auto it = registry.find(handle);
559 if (it == registry.end())
return 0;
560 return it->second.model.count;
568 auto it = registry.find(handle);
569 if (it == registry.end())
return 0;
570 return it->second.sampling.count;
std::unordered_map< PerplexityHandle, PerplexityState > & get_registry()
std::unordered_map< BranchMetricsHandle, BranchMetricsState > & get_branch_metrics_registry()
float max_finite(const float *a, int n)
Find the maximum finite value in an array. Used as the log-sum-exp shift to prevent overflow.
float log_sum_exp(const float *a, int n, float shift)
Numerically stable log-sum-exp. Computes log(Σ exp(aᵢ)) using the shift trick to avoid overflow; non-finite entries are skipped, and the result is -infinity if the shifted sum is zero.
PerplexityHandle & get_next_handle()
BranchMetricsHandle & get_next_branch_metrics_handle()
float get_model_ppl(BranchMetricsHandle handle)
Get model-level perplexity (from raw logits)
void add_model_surprisal(BranchMetricsHandle handle, float surprisal)
Add model-level surprisal (from raw logits before filters)
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
int get_sampling_count(BranchMetricsHandle handle)
Get number of tokens in sampling-level tracker.
void add_surprisal(PerplexityHandle handle, float surprisal)
Add token surprisal to running average.
void reset_perplexity(PerplexityHandle handle)
Reset tracker to initial state (start new sequence)
float get_sampling_ppl(BranchMetricsHandle handle)
Get sampling-level perplexity (from filtered distribution)
PerplexityHandle create_perplexity()
BranchMetricsHandle clone_branch_metrics(BranchMetricsHandle handle)
Clone branch metrics tracker (for fork/branching)
BranchMetricsHandle create_branch_metrics()
Create unified branch metrics tracker.
float model_entropy(const float *logits, int n_vocab, Base base=Base::Nats)
float sampling_entropy(const float *candidate_logits, int n_candidates, Base base=Base::Nats)
Compute sampling-level entropy of candidate distribution.
int get_count(PerplexityHandle handle)
Get number of tokens added to tracker.
float get_ppl(PerplexityHandle handle)
Get current perplexity.
void free_branch_metrics(BranchMetricsHandle handle)
Free branch metrics tracker.
PerplexityHandle clone_perplexity(PerplexityHandle handle)
void free_perplexity(PerplexityHandle handle)
Free perplexity tracker.
int get_model_count(BranchMetricsHandle handle)
Get number of tokens in model-level tracker.
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
void add_sampling_surprisal(BranchMetricsHandle handle, float surprisal)
Add sampling-level surprisal (from filtered distribution)
int32_t BranchMetricsHandle