liblloyal 1.0.0
Composable primitives for llama.cpp inference
embedding.hpp
#pragma once

// SPDX-License-Identifier: Apache-2.0
// Copyright 2026 Lloyal Labs

#include "common.hpp"
#include "helpers.hpp"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <llama/llama.h>
#include <stdexcept>
#include <vector>

namespace lloyal::embedding {

// ===== NORMALIZATION MODES =====

/// Normalization modes for embedding vectors.
enum class Normalize : int32_t {
  None = 0, // No normalization (raw embeddings)
  L2 = 1,   // L2 normalization (unit length, required for cosine similarity)
};

// ===== MODEL CAPABILITY CHECKS =====

/// Check if model supports embeddings.
inline bool has_embeddings(const llama_model *model) {
  if (!model) {
    LLOYAL_LOG_DEBUG("[embedding::has_embeddings] ERROR: model is null");
    return false;
  }

  const int32_t n_embd = llama_model_n_embd(model);
  return n_embd > 0;
}

/// Get embedding dimension for model.
inline int32_t dimension(const llama_model *model) {
  if (!model) {
    LLOYAL_LOG_DEBUG("[embedding::dimension] ERROR: model is null");
    return 0;
  }

  return llama_model_n_embd(model);
}
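
// Usage sketch (illustrative, not part of the original header): gate embedding
// work on model capability before sizing buffers. `model` is assumed to come
// from the caller's llama.cpp model-loading code.
//
//   if (lloyal::embedding::has_embeddings(model)) {
//     const int32_t dim = lloyal::embedding::dimension(model);
//     std::vector<float> buffer(static_cast<size_t>(dim));
//   }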

// ===== CONTEXT CAPABILITY CHECKS =====

/// Check if context has pooling enabled.
inline bool has_pooling(llama_context *ctx) {
  if (!ctx) {
    LLOYAL_LOG_DEBUG("[embedding::has_pooling] ERROR: ctx is null");
    return false;
  }

  return llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
}

/// Get pooling type for context (LLAMA_POOLING_TYPE_NONE if ctx is null).
inline int32_t pooling_type(llama_context *ctx) {
  if (!ctx) {
    LLOYAL_LOG_DEBUG("[embedding::pooling_type] ERROR: ctx is null");
    return LLAMA_POOLING_TYPE_NONE;
  }

  return llama_pooling_type(ctx);
}
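
// Usage sketch (assumes the standard llama.cpp context-params API; the field
// names below come from llama.h, not from this header, and
// llama_init_from_model is the recent name for what older builds call
// llama_new_context_with_model). Pooled embeddings require a context created
// with embeddings enabled and a pooling type other than NONE:
//
//   llama_context_params cparams = llama_context_default_params();
//   cparams.embeddings = true;                      // expose embedding outputs
//   cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // mean-pool token states
//   llama_context *ctx = llama_init_from_model(model, cparams);
//   // has_pooling(ctx) -> true; pooling_type(ctx) -> LLAMA_POOLING_TYPE_MEAN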

// ===== INTERNAL HELPERS =====

namespace detail {

/// Apply L2 normalization to embedding vector (in-place).
inline void apply_l2_normalize(std::vector<float> &vec) {
  if (vec.empty())
    return;

  float norm_sq = 0.0f;
  for (float v : vec) {
    norm_sq += v * v;
  }

  float norm = std::sqrt(norm_sq);
  if (norm > 1e-8f) { // Avoid division by zero
    for (float &v : vec) {
      v /= norm;
    }
  } else {
    LLOYAL_LOG_DEBUG(
        "[embedding::detail::apply_l2_normalize] WARNING: near-zero norm");
  }
}

} // namespace detail

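// Worked example of the helper above: for v = [3, 4], norm_sq = 9 + 16 = 25,
// so norm = 5 and the normalized vector is [0.6, 0.8], which has unit length
// (0.36 + 0.64 = 1). Illustrative only:
//
//   std::vector<float> v{3.0f, 4.0f};
//   lloyal::embedding::detail::apply_l2_normalize(v);
//   // v is now approximately {0.6f, 0.8f}
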
// ===== RAII GUARD FOR BATCH CLEANUP =====

namespace detail {

/// RAII guard for automatic batch cleanup.
/// Ensures llama_batch_free is called even if exceptions occur.
struct BatchGuard {
  llama_batch &batch;
  explicit BatchGuard(llama_batch &b) : batch(b) {}
  BatchGuard(const BatchGuard &) = delete;
  BatchGuard &operator=(const BatchGuard &) = delete;
  ~BatchGuard() { llama_batch_free(batch); }
};

} // namespace detail

// ===== ENCODING (FORWARD PASS FOR EMBEDDINGS) =====

/// Encode tokens for embedding extraction.
inline void encode(llama_context *ctx, const llama_token *tokens,
                   int32_t n_tokens, int32_t n_batch) {
  LLOYAL_LOG_DEBUG("[embedding::encode] Encoding %d tokens for embeddings",
                   n_tokens);

  if (!ctx) {
    LLOYAL_LOG_DEBUG("[embedding::encode] ERROR: NULL context");
    throw std::runtime_error("embedding::encode - NULL context");
  }

  if (!tokens || n_tokens <= 0) {
    LLOYAL_LOG_DEBUG("[embedding::encode] ERROR: Invalid token array");
    throw std::runtime_error("embedding::encode - Invalid token array");
  }

  if (n_tokens > n_batch) {
    LLOYAL_LOG_DEBUG("[embedding::encode] ERROR: n_tokens (%d) > n_batch (%d)",
                     n_tokens, n_batch);
    throw std::runtime_error(
        "embedding::encode - token count exceeds batch size (truncation not "
        "supported, increase n_batch or reduce input length)");
  }

  // Initialize batch - single sequence; the guard frees it on all exit paths
  llama_batch batch = llama_batch_init(n_batch, 0, 1);
  detail::BatchGuard guard(batch);

  // Clear batch
  lloyal::batch_clear(batch);

  // Add ALL tokens with logits=true (required for embedding extraction)
  for (int32_t i = 0; i < n_tokens; ++i) {
    lloyal::batch_add(batch, tokens[i], i, {0}, true, n_batch);
  }

  // Decode/encode the batch (llama.cpp handles encoder vs decoder internally)
  if (llama_decode(ctx, batch) != 0) {
    LLOYAL_LOG_DEBUG("[embedding::encode] ERROR: llama_decode failed");
    throw std::runtime_error("embedding::encode - llama_decode failed");
  }

  LLOYAL_LOG_DEBUG("[embedding::encode] Encode complete");
}

/// Encode tokens for embedding extraction (vector overload).
inline void encode(llama_context *ctx, const std::vector<llama_token> &tokens,
                   int32_t n_batch) {
  encode(ctx, tokens.data(), static_cast<int32_t>(tokens.size()), n_batch);
}
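
// Usage sketch (illustrative; `ctx` and `tokens` are assumed to come from the
// caller, e.g. tokens produced by llama_tokenize):
//
//   std::vector<llama_token> tokens = /* tokenize the input text */;
//   const int32_t n_batch = 512; // must be >= tokens.size()
//   lloyal::embedding::encode(ctx, tokens, n_batch);
//   // pooled embeddings for the batch are now readable via get()/get_seq()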

// ===== EMBEDDING EXTRACTION =====

/// Get embeddings for last decoded batch.
inline std::vector<float> get(llama_context *ctx,
                              Normalize normalize = Normalize::L2) {
  if (!ctx) {
    LLOYAL_LOG_DEBUG("[embedding::get] ERROR: ctx is null");
    throw std::invalid_argument("embedding::get: ctx is null");
  }

  // Get model to determine embedding dimension
  const llama_model *model = llama_get_model(ctx);
  if (!model) {
    LLOYAL_LOG_DEBUG("[embedding::get] ERROR: failed to get model from context");
    throw std::runtime_error("embedding::get: failed to get model");
  }

  // Warn if pooling not enabled (embeddings may be invalid)
  if (!has_pooling(ctx)) {
    LLOYAL_LOG_DEBUG(
        "[embedding::get] WARNING: pooling not enabled, embeddings may be "
        "invalid. Create context with pooling_type != NONE");
  }

  // Get embeddings pointer from llama.cpp
  // For pooled embeddings, use sequence-specific API (sequence 0)
  const float *embd_ptr = nullptr;
  if (has_pooling(ctx)) {
    embd_ptr = llama_get_embeddings_seq(ctx, 0);
    LLOYAL_LOG_DEBUG("[embedding::get] Using llama_get_embeddings_seq for pooled "
                     "embeddings");
  } else {
    embd_ptr = llama_get_embeddings(ctx);
    LLOYAL_LOG_DEBUG("[embedding::get] Using llama_get_embeddings (no pooling)");
  }

  if (!embd_ptr) {
    LLOYAL_LOG_DEBUG("[embedding::get] ERROR: embeddings pointer is null. "
                     "Ensure context was created with embeddings=true and "
                     "tokens were encoded with logits=true for all tokens.");
    throw std::runtime_error(
        "embedding::get: embeddings unavailable (ensure embeddings=true in "
        "context params and use embedding::encode())");
  }

  // Copy to vector
  const int32_t n_embd = llama_model_n_embd(model);
  std::vector<float> embeddings(embd_ptr, embd_ptr + n_embd);

  // Apply normalization
  if (normalize == Normalize::L2) {
    detail::apply_l2_normalize(embeddings);
  }

  LLOYAL_LOG_DEBUG("[embedding::get] Extracted embeddings (dim=%d, normalize=%d)",
                   n_embd, static_cast<int>(normalize));

  return embeddings;
}
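
// End-to-end sketch (assumes ctx was created with embeddings=true and a
// pooling type other than NONE, as in the context-params sketch above):
//
//   lloyal::embedding::encode(ctx, tokens, /*n_batch=*/512);
//   std::vector<float> emb =
//       lloyal::embedding::get(ctx, lloyal::embedding::Normalize::L2);
//   // emb.size() == dimension(model); unit length, ready for cosine compare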

/// Get embeddings for specific sequence.
inline std::vector<float> get_seq(llama_context *ctx, llama_seq_id seq,
                                  Normalize normalize = Normalize::L2) {
  if (!ctx) {
    LLOYAL_LOG_DEBUG("[embedding::get_seq] ERROR: ctx is null");
    throw std::invalid_argument("embedding::get_seq: ctx is null");
  }

  const llama_model *model = llama_get_model(ctx);
  if (!model) {
    LLOYAL_LOG_DEBUG("[embedding::get_seq] ERROR: failed to get model");
    throw std::runtime_error("embedding::get_seq: failed to get model");
  }

  if (!has_pooling(ctx)) {
    LLOYAL_LOG_DEBUG("[embedding::get_seq] WARNING: pooling not enabled");
  }

  // Try sequence-specific API
  const float *embd_ptr = llama_get_embeddings_seq(ctx, seq);

  // Fallback to global embeddings for seq=0
  if (!embd_ptr) {
    if (seq == 0) {
      LLOYAL_LOG_DEBUG("[embedding::get_seq] Falling back to get() for seq=0");
      return get(ctx, normalize);
    }
    LLOYAL_LOG_DEBUG("[embedding::get_seq] ERROR: embeddings unavailable for "
                     "seq=%d",
                     seq);
    throw std::runtime_error("embedding::get_seq: embeddings unavailable");
  }

  const int32_t n_embd = llama_model_n_embd(model);
  std::vector<float> embeddings(embd_ptr, embd_ptr + n_embd);

  if (normalize == Normalize::L2) {
    detail::apply_l2_normalize(embeddings);
  }

  LLOYAL_LOG_DEBUG("[embedding::get_seq] Extracted embeddings for seq=%d "
                   "(dim=%d)",
                   seq, n_embd);

  return embeddings;
}
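
// Multi-sequence sketch (hypothetical; encode() above always uses sequence 0,
// so a caller wanting per-sequence pooled embeddings would build its own batch
// with distinct seq_ids via lloyal::batch_add, then decode once):
//
//   // ... batch_add(batch, tok, pos, {/*seq_id=*/0 or 1}, true, n_batch) ...
//   // ... llama_decode(ctx, batch) ...
//   std::vector<float> doc0 = lloyal::embedding::get_seq(ctx, 0);
//   std::vector<float> doc1 = lloyal::embedding::get_seq(ctx, 1);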

/// Get embeddings for specific token index in last batch.
inline std::vector<float> get_ith(llama_context *ctx, int32_t idx,
                                  Normalize normalize = Normalize::L2) {
  if (!ctx) {
    LLOYAL_LOG_DEBUG("[embedding::get_ith] ERROR: ctx is null");
    throw std::invalid_argument("embedding::get_ith: ctx is null");
  }

  const llama_model *model = llama_get_model(ctx);
  if (!model) {
    LLOYAL_LOG_DEBUG("[embedding::get_ith] ERROR: failed to get model");
    throw std::runtime_error("embedding::get_ith: failed to get model");
  }

  const float *embd_ptr = llama_get_embeddings_ith(ctx, idx);
  if (!embd_ptr) {
    LLOYAL_LOG_DEBUG("[embedding::get_ith] ERROR: embeddings unavailable for "
                     "idx=%d",
                     idx);
    throw std::runtime_error("embedding::get_ith: embeddings unavailable");
  }

  const int32_t n_embd = llama_model_n_embd(model);
  std::vector<float> embeddings(embd_ptr, embd_ptr + n_embd);

  if (normalize == Normalize::L2) {
    detail::apply_l2_normalize(embeddings);
  }

  LLOYAL_LOG_DEBUG("[embedding::get_ith] Extracted embeddings for idx=%d "
                   "(dim=%d)",
                   idx, n_embd);

  return embeddings;
}
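
// Per-token sketch (illustrative; mainly useful without pooling, where each
// position keeps its own hidden-state embedding):
//
//   lloyal::embedding::encode(ctx, tokens, n_batch);
//   std::vector<float> last = lloyal::embedding::get_ith(
//       ctx, static_cast<int32_t>(tokens.size()) - 1);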

// ===== SIMILARITY =====

/// Compute cosine similarity between two embedding vectors.
/// Assumes both inputs are L2-normalized (as produced by get() with
/// Normalize::L2); for normalized vectors the dot product equals cosine
/// similarity.
inline float cosine_similarity(const std::vector<float> &a,
                               const std::vector<float> &b) {
  if (a.size() != b.size()) {
    LLOYAL_LOG_DEBUG("[embedding::cosine_similarity] ERROR: dimension mismatch "
                     "(%zu vs %zu)",
                     a.size(), b.size());
    throw std::invalid_argument(
        "embedding::cosine_similarity: dimension mismatch");
  }

  if (a.empty()) {
    return 0.0f;
  }

  // For L2-normalized vectors, cosine similarity = dot product
  float dot = 0.0f;
  for (size_t i = 0; i < a.size(); ++i) {
    dot += a[i] * b[i];
  }

  return dot;
}

} // namespace lloyal::embedding
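
// Similarity sketch (illustrative; `embed` is a hypothetical caller-side
// wrapper that tokenizes, calls encode(), then get()). With Normalize::L2 the
// result is true cosine similarity in [-1, 1]:
//
//   std::vector<float> a = embed(ctx, "the cat sat on the mat");
//   std::vector<float> b = embed(ctx, "a cat was sitting on a mat");
//   float sim = lloyal::embedding::cosine_similarity(a, b);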