liblloyal 1.0.0
Composable primitives for llama.cpp inference
Loading...
Searching...
No Matches
/home/runner/work/liblloyal/liblloyal/include/lloyal/metrics.hpp

Compute model-level surprisal for picked token.

Compute model-level surprisal for the picked token. Surprisal = -log p(tokenₜ | context), i.e. the model's uncertainty about its choice; higher surprisal means a more surprising (lower-probability) token.

Uses the raw model logits (before temperature/top-k/top-p) to measure the model's inherent uncertainty.

Parameters
logits — Full-vocabulary logits (before sampling filters)
n_vocab — Vocabulary size
picked_id — ID of the token that was sampled
base — Nats (natural log) or Bits (log₂)
Returns
Surprisal in nats or bits (≥ 0; infinity if the inputs are invalid)

float* logits = lloyal::logits::get(ctx); int n_vocab = llama_vocab_n_tokens(vocab); llama_token token = sample(logits); float s = metrics::model_surprisal(logits, n_vocab, token); if (s > 5.0f) { // High uncertainty - consider retrieval }

#pragma once
// SPDX-License-Identifier: Apache-2.0
// Copyright 2026 Lloyal Labs
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <vector>
namespace lloyal::metrics {

// ============================================================================
// Types
// ============================================================================

/// Unit in which information quantities are reported.
/// Nats use the natural logarithm; Bits use log2 (value in nats / ln 2).
enum class Base { Nats, Bits };

/// Handle to a rolling-perplexity tracker held in a shared registry.
/// Handles are issued sequentially starting at 1; 0 is never issued and is
/// returned by clone_perplexity() when the source handle is unknown.
using PerplexityHandle = int32_t;

// ============================================================================
// Internal helpers (ported from metrics.ts)
// ============================================================================

namespace detail {

/// ln(2); dividing a value in nats by this yields bits.
constexpr float LN2 = 0.693147180559945309417232121458176568f;

/// Find the maximum finite value in a[0, n).
/// Used as the log-sum-exp shift so that exp() cannot overflow.
/// @return the maximum, or -infinity if no entry is finite (or n <= 0).
inline float max_finite(const float* a, int n) {
  float m = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < n; ++i) {
    const float v = a[i];
    if (std::isfinite(v) && v > m) m = v;
  }
  return m;
}

/// Numerically stable log(sum_i exp(a[i])) over the finite entries of a[0, n).
/// Non-finite entries (NaN, +/-inf, e.g. masked-out tokens) are skipped.
/// @param shift  subtracted from every entry before exp(); pass max_finite().
/// @return the log-sum-exp, or -infinity if no finite entry contributed.
inline float log_sum_exp(const float* a, int n, float shift) {
  // Accumulate in double: n may span a whole vocabulary, and a float running
  // sum drops low-order contributions as it grows.
  double s = 0.0;
  for (int i = 0; i < n; ++i) {
    const float v = a[i];
    if (std::isfinite(v)) s += std::exp(v - shift);
  }
  if (s == 0.0) return -std::numeric_limits<float>::infinity();
  return shift + static_cast<float>(std::log(s));
}

/// Convert a quantity measured in nats to the requested base.
inline float to_base(float nats, Base base) {
  return base == Base::Bits ? nats / LN2 : nats;
}

/// Surprisal, in nats, of entry `picked` under softmax(a[0, n)), i.e.
/// log Z - a[picked]. The caller guarantees 0 <= picked < n.
/// @return +infinity if a[picked] is non-finite or no entry is finite.
inline float softmax_surprisal_nats(const float* a, int n, int picked) {
  const float z = a[picked];
  if (!std::isfinite(z)) return std::numeric_limits<float>::infinity();
  const float m = max_finite(a, n);
  if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
  const float log_z = log_sum_exp(a, n, m);
  if (!std::isfinite(log_z)) return std::numeric_limits<float>::infinity();
  // True surprisal is never negative; clamp away rounding noise.
  return std::max(0.0f, log_z - z);
}

/// Shannon entropy, in nats, of the softmax over the finite entries of
/// a[0, n). Non-finite entries are treated as zero-probability.
/// @return the entropy (>= 0), or +infinity if no entry is finite.
inline float softmax_entropy_nats(const float* a, int n) {
  const float m = max_finite(a, n);
  if (!std::isfinite(m)) return std::numeric_limits<float>::infinity();
  // With d_i = z_i - m and w_i = exp(d_i):
  //   H = log(sum w) - (sum w_i * d_i) / (sum w)
  // Working with shifted logits keeps both terms small. The unshifted form
  // log Z - E[z] subtracts two nearly equal large numbers and loses all
  // precision when raw logits are large.
  double sum_w = 0.0;
  double sum_wd = 0.0;
  for (int i = 0; i < n; ++i) {
    const float z = a[i];
    if (!std::isfinite(z)) continue;
    const float d = z - m;  // <= 0; may overflow to -inf for extreme ranges
    const float w = std::exp(d);
    if (w == 0.0f) continue;  // contributes nothing; also avoids 0 * -inf
    sum_w += w;
    sum_wd += static_cast<double>(w) * d;
  }
  // The max entry contributes w == 1, so sum_w >= 1 here; kept defensive.
  if (sum_w == 0.0) return std::numeric_limits<float>::infinity();
  const double h = std::log(sum_w) - sum_wd / sum_w;
  return static_cast<float>(std::max(0.0, h));
}

/// Running negative-log-likelihood state for one perplexity tracker.
struct PerplexityState {
  float nll_sum_nats = 0.0f;  ///< Sum of accepted surprisals, in nats.
  int count = 0;              ///< Number of accepted surprisals.
};

/// Add one surprisal (nats) to `state`. Non-finite values are ignored and
/// negative values are clamped to 0 before being added.
inline void accumulate(PerplexityState& state, float surprisal_nats) {
  if (!std::isfinite(surprisal_nats)) return;
  state.nll_sum_nats += std::max(0.0f, surprisal_nats);
  ++state.count;
}

/// exp(mean surprisal) of `state`, or +infinity if it holds no tokens.
inline float perplexity_of(const PerplexityState& state) {
  if (state.count == 0) return std::numeric_limits<float>::infinity();
  return std::exp(state.nll_sum_nats / static_cast<float>(state.count));
}

/// Shared registry of perplexity trackers, keyed by handle.
/// NOTE: no synchronization is performed; callers must serialize access.
inline std::unordered_map<PerplexityHandle, PerplexityState>& get_registry() {
  static std::unordered_map<PerplexityHandle, PerplexityState> registry;
  return registry;
}

/// Counter holding the next perplexity handle to issue (starts at 1).
inline PerplexityHandle& get_next_handle() {
  static PerplexityHandle next = 1;
  return next;
}

/// Look up a perplexity tracker; nullptr if the handle is unknown.
/// The pointer stays valid until that handle is freed (unordered_map
/// rehashing does not invalidate element pointers).
inline PerplexityState* find_perplexity(PerplexityHandle handle) {
  auto& registry = get_registry();
  const auto it = registry.find(handle);
  return it == registry.end() ? nullptr : &it->second;
}

}  // namespace detail

// ============================================================================
// Model-level metrics (raw logits, before filters)
// ============================================================================

/// Compute model-level surprisal for the picked token.
///
/// Surprisal = -log p(token | context), computed from the raw model logits
/// (before temperature/top-k/top-p), so it reflects the model's inherent
/// uncertainty. Higher surprisal means a less probable token.
///
/// @param logits     full-vocabulary logits (before sampling filters)
/// @param n_vocab    number of entries in `logits`
/// @param picked_id  token ID that was sampled
/// @param base       Nats (natural log) or Bits (log2)
/// @return surprisal (>= 0), or +infinity if the inputs are invalid or the
///         picked logit is non-finite.
inline float model_surprisal(const float* logits, int n_vocab, int picked_id,
                             Base base = Base::Nats) {
  if (!logits || n_vocab <= 0 || picked_id < 0 || picked_id >= n_vocab) {
    return std::numeric_limits<float>::infinity();
  }
  return detail::to_base(
      detail::softmax_surprisal_nats(logits, n_vocab, picked_id), base);
}

/// Compute model-level entropy of the full-vocabulary distribution.
///
/// @param logits   full-vocabulary logits (before sampling filters)
/// @param n_vocab  number of entries in `logits`
/// @param base     Nats (natural log) or Bits (log2)
/// @return entropy (>= 0), or +infinity if the inputs are invalid or no
///         logit is finite.
inline float model_entropy(const float* logits, int n_vocab,
                           Base base = Base::Nats) {
  if (!logits || n_vocab <= 0) {
    return std::numeric_limits<float>::infinity();
  }
  return detail::to_base(detail::softmax_entropy_nats(logits, n_vocab), base);
}

// ============================================================================
// Sampling-level metrics (post-filter logits, after top-k/p/temp)
// ============================================================================

/// Compute sampling-level surprisal for the picked token.
///
/// The distribution is the softmax over the surviving candidates only.
/// `candidate_logits[i]` is the logit of token `candidate_ids[i]`.
///
/// @param candidate_logits  logits of the candidates (after filters)
/// @param candidate_ids     token IDs, parallel to `candidate_logits`
/// @param n_candidates      number of candidates
/// @param picked_id         token ID that was sampled
/// @param base              Nats (natural log) or Bits (log2)
/// @return surprisal (>= 0); 0 if the picked token was the sole candidate;
///         +infinity if the inputs are invalid or `picked_id` is not a
///         candidate.
inline float sampling_surprisal(const float* candidate_logits,
                                const int32_t* candidate_ids, int n_candidates,
                                int picked_id, Base base = Base::Nats) {
  if (!candidate_logits || !candidate_ids || n_candidates <= 0) {
    return std::numeric_limits<float>::infinity();
  }
  const int32_t* const ids_end = candidate_ids + n_candidates;
  const int32_t* const hit = std::find(candidate_ids, ids_end, picked_id);
  if (hit == ids_end) return std::numeric_limits<float>::infinity();
  if (n_candidates == 1) return 0.0f;
  const int local = static_cast<int>(hit - candidate_ids);
  return detail::to_base(
      detail::softmax_surprisal_nats(candidate_logits, n_candidates, local),
      base);
}

/// Compute sampling-level entropy of the candidate distribution.
///
/// @param candidate_logits  logits of the candidates (after filters)
/// @param n_candidates      number of candidates
/// @param base              Nats (natural log) or Bits (log2)
/// @return entropy (>= 0); 0 for a single candidate; +infinity if the inputs
///         are invalid or no logit is finite.
inline float sampling_entropy(const float* candidate_logits, int n_candidates,
                              Base base = Base::Nats) {
  if (!candidate_logits || n_candidates <= 0) {
    return std::numeric_limits<float>::infinity();
  }
  if (n_candidates == 1) return 0.0f;
  return detail::to_base(
      detail::softmax_entropy_nats(candidate_logits, n_candidates), base);
}

// ============================================================================
// Handle-based RollingPerplexity (supports clone for fork)
// ============================================================================

/// Create an empty rolling-perplexity tracker.
/// @return a new, non-zero handle.
inline PerplexityHandle create_perplexity() {
  const PerplexityHandle h = detail::get_next_handle()++;
  detail::get_registry()[h] = detail::PerplexityState{};
  return h;
}

/// Add token surprisal (in nats) to the tracker's running average.
/// Unknown handles and non-finite values are ignored; negative values are
/// clamped to 0.
inline void add_surprisal(PerplexityHandle handle, float surprisal) {
  if (detail::PerplexityState* state = detail::find_perplexity(handle)) {
    detail::accumulate(*state, surprisal);
  }
}

/// Get current perplexity: exp(mean surprisal in nats).
/// @return +infinity for an unknown handle or an empty tracker.
inline float get_ppl(PerplexityHandle handle) {
  const detail::PerplexityState* state = detail::find_perplexity(handle);
  return state ? detail::perplexity_of(*state)
               : std::numeric_limits<float>::infinity();
}

/// Get number of tokens added to the tracker (0 for an unknown handle).
inline int get_count(PerplexityHandle handle) {
  const detail::PerplexityState* state = detail::find_perplexity(handle);
  return state ? state->count : 0;
}

/// Reset the tracker to its initial state (start a new sequence).
/// Unknown handles are ignored.
inline void reset_perplexity(PerplexityHandle handle) {
  if (detail::PerplexityState* state = detail::find_perplexity(handle)) {
    *state = detail::PerplexityState{};
  }
}

/// Clone a tracker (e.g. when forking a generation branch).
/// @return a new handle holding a copy of the state, or 0 if `handle` is
///         unknown (no handle is consumed in that case).
inline PerplexityHandle clone_perplexity(PerplexityHandle handle) {
  const detail::PerplexityState* src = detail::find_perplexity(handle);
  if (!src) return 0;
  const PerplexityHandle new_handle = detail::get_next_handle()++;
  // operator[] may rehash; that invalidates iterators but not element
  // pointers, so `src` is still valid here.
  detail::get_registry()[new_handle] = *src;
  return new_handle;
}

/// Free a tracker. Unknown handles are ignored.
inline void free_perplexity(PerplexityHandle handle) {
  detail::get_registry().erase(handle);
}

// ============================================================================
// Branch-Level Metrics (unified model + sampling tracking)
// ============================================================================

/// Handle to a branch metrics tracker (model + sampling perplexity).
/// Issued sequentially from 1; 0 signals failure from clone_branch_metrics().
using BranchMetricsHandle = int32_t;

namespace detail {

/// Paired model-level and sampling-level perplexity state for one branch.
struct BranchMetricsState {
  PerplexityState model;     // Model-level (raw logits before filters)
  PerplexityState sampling;  // Sampling-level (post top-k/p/temp)
};

/// Shared registry of branch trackers, keyed by handle.
/// NOTE: no synchronization is performed; callers must serialize access.
inline std::unordered_map<BranchMetricsHandle, BranchMetricsState>&
get_branch_metrics_registry() {
  static std::unordered_map<BranchMetricsHandle, BranchMetricsState> registry;
  return registry;
}

/// Counter holding the next branch metrics handle to issue (starts at 1).
inline BranchMetricsHandle& get_next_branch_metrics_handle() {
  static BranchMetricsHandle next = 1;
  return next;
}

/// Look up a branch tracker; nullptr if the handle is unknown.
inline BranchMetricsState* find_branch(BranchMetricsHandle handle) {
  auto& registry = get_branch_metrics_registry();
  const auto it = registry.find(handle);
  return it == registry.end() ? nullptr : &it->second;
}

}  // namespace detail

/// Create a unified branch metrics tracker with empty model and sampling
/// state.
/// @return a new, non-zero handle.
inline BranchMetricsHandle create_branch_metrics() {
  const BranchMetricsHandle h = detail::get_next_branch_metrics_handle()++;
  detail::get_branch_metrics_registry()[h] = detail::BranchMetricsState{};
  return h;
}

/// Free a branch metrics tracker. Unknown handles are ignored.
inline void free_branch_metrics(BranchMetricsHandle handle) {
  detail::get_branch_metrics_registry().erase(handle);
}

/// Clone a branch metrics tracker (for fork/branching).
/// @return a new handle holding a copy of both model and sampling state, or
///         0 if `handle` is unknown (no handle is consumed in that case).
inline BranchMetricsHandle clone_branch_metrics(BranchMetricsHandle handle) {
  const detail::BranchMetricsState* src = detail::find_branch(handle);
  if (!src) return 0;
  const BranchMetricsHandle new_handle =
      detail::get_next_branch_metrics_handle()++;
  // Rehashing does not invalidate element pointers, so `src` stays valid.
  detail::get_branch_metrics_registry()[new_handle] = *src;
  return new_handle;
}

/// Add model-level surprisal (nats, from raw logits before filters).
/// Unknown handles and non-finite values are ignored; negatives clamp to 0.
inline void add_model_surprisal(BranchMetricsHandle handle, float surprisal) {
  if (detail::BranchMetricsState* state = detail::find_branch(handle)) {
    detail::accumulate(state->model, surprisal);
  }
}

/// Get model-level perplexity (from raw logits).
/// @return +infinity for an unknown handle or no model-level tokens.
inline float get_model_ppl(BranchMetricsHandle handle) {
  const detail::BranchMetricsState* state = detail::find_branch(handle);
  return state ? detail::perplexity_of(state->model)
               : std::numeric_limits<float>::infinity();
}

/// Add sampling-level surprisal (nats, from the filtered distribution).
/// Unknown handles and non-finite values are ignored; negatives clamp to 0.
inline void add_sampling_surprisal(BranchMetricsHandle handle,
                                   float surprisal) {
  if (detail::BranchMetricsState* state = detail::find_branch(handle)) {
    detail::accumulate(state->sampling, surprisal);
  }
}

/// Get sampling-level perplexity (from the filtered distribution).
/// @return +infinity for an unknown handle or no sampling-level tokens.
inline float get_sampling_ppl(BranchMetricsHandle handle) {
  const detail::BranchMetricsState* state = detail::find_branch(handle);
  return state ? detail::perplexity_of(state->sampling)
               : std::numeric_limits<float>::infinity();
}

/// Get number of tokens in the model-level tracker (0 if unknown handle).
inline int get_model_count(BranchMetricsHandle handle) {
  const detail::BranchMetricsState* state = detail::find_branch(handle);
  return state ? state->model.count : 0;
}

/// Get number of tokens in the sampling-level tracker (0 if unknown handle).
inline int get_sampling_count(BranchMetricsHandle handle) {
  const detail::BranchMetricsState* state = detail::find_branch(handle);
  return state ? state->sampling.count : 0;
}

}  // namespace lloyal::metrics
std::unordered_map< PerplexityHandle, PerplexityState > & get_registry()
Definition metrics.hpp:92
constexpr float LN2
Definition metrics.hpp:52
std::unordered_map< BranchMetricsHandle, BranchMetricsState > & get_branch_metrics_registry()
Definition metrics.hpp:437
float max_finite(const float *a, int n)
Find the maximum finite value in an array. Used as the log-sum-exp shift to prevent overflow.
Definition metrics.hpp:58
float log_sum_exp(const float *a, int n, float shift)
Numerically stable log-sum-exp. Computes log(Σ exp(aᵢ)) using the shift trick to avoid overflow.
Definition metrics.hpp:76
PerplexityHandle & get_next_handle()
Definition metrics.hpp:97
BranchMetricsHandle & get_next_branch_metrics_handle()
Definition metrics.hpp:442
float get_model_ppl(BranchMetricsHandle handle)
Get model-level perplexity (from raw logits)
Definition metrics.hpp:512
void add_model_surprisal(BranchMetricsHandle handle, float surprisal)
Add model-level surprisal (from raw logits before filters)
Definition metrics.hpp:497
float sampling_surprisal(const float *candidate_logits, const int32_t *candidate_ids, int n_candidates, int picked_id, Base base=Base::Nats)
Compute sampling-level surprisal for picked token.
Definition metrics.hpp:226
int get_sampling_count(BranchMetricsHandle handle)
Get number of tokens in sampling-level tracker.
Definition metrics.hpp:566
void add_surprisal(PerplexityHandle handle, float surprisal)
Add token surprisal to running average.
Definition metrics.hpp:339
void reset_perplexity(PerplexityHandle handle)
Reset tracker to initial state (start new sequence)
Definition metrics.hpp:381
float get_sampling_ppl(BranchMetricsHandle handle)
Get sampling-level perplexity (from filtered distribution)
Definition metrics.hpp:543
int32_t PerplexityHandle
Definition metrics.hpp:44
PerplexityHandle create_perplexity()
Definition metrics.hpp:327
BranchMetricsHandle clone_branch_metrics(BranchMetricsHandle handle)
Clone branch metrics tracker (for fork/branching)
Definition metrics.hpp:481
BranchMetricsHandle create_branch_metrics()
Create unified branch metrics tracker.
Definition metrics.hpp:457
float model_entropy(const float *logits, int n_vocab, Base base=Base::Nats)
Definition metrics.hpp:180
float sampling_entropy(const float *candidate_logits, int n_candidates, Base base=Base::Nats)
Compute sampling-level entropy of candidate distribution.
Definition metrics.hpp:272
int get_count(PerplexityHandle handle)
Get number of tokens added to tracker.
Definition metrics.hpp:369
float get_ppl(PerplexityHandle handle)
Get current perplexity.
Definition metrics.hpp:354
void free_branch_metrics(BranchMetricsHandle handle)
Free branch metrics tracker.
Definition metrics.hpp:468
PerplexityHandle clone_perplexity(PerplexityHandle handle)
Definition metrics.hpp:404
void free_perplexity(PerplexityHandle handle)
Free perplexity tracker.
Definition metrics.hpp:419
int get_model_count(BranchMetricsHandle handle)
Get number of tokens in model-level tracker.
Definition metrics.hpp:556
float model_surprisal(const float *logits, int n_vocab, int picked_id, Base base=Base::Nats)
Definition metrics.hpp:131
void add_sampling_surprisal(BranchMetricsHandle handle, float surprisal)
Add sampling-level surprisal (from filtered distribution)
Definition metrics.hpp:528
int32_t BranchMetricsHandle
Definition metrics.hpp:427