110 const std::vector<std::span<const llama_token>>& prompts,
111 std::vector<float*>& output,
114 if (!ctx)
throw std::runtime_error(
"logits::process_chunks - NULL context");
115 if (prompts.size() != output.size())
116 throw std::runtime_error(
"logits::process_chunks - prompts/output size mismatch");
117 if (prompts.empty())
return;
119 const int32_t seq_max =
static_cast<int32_t
>(llama_n_seq_max(ctx));
120 const int32_t batch_limit =
static_cast<int32_t
>(llama_n_batch(ctx));
121 const int32_t n =
static_cast<int32_t
>(prompts.size());
124 for (int32_t group_start = 0; group_start < n; group_start += seq_max) {
125 int32_t group_size = std::min(seq_max, n - group_start);
128 auto chunks =
decode::bin_pack(&prompts[group_start], group_size, batch_limit);
129 if (chunks.empty())
continue;
131 for (
const auto& chunk : chunks) {
132 if (chunk.oversized) {
133 int32_t gi = group_start + chunk.indices[0];
134 llama_seq_id seq =
static_cast<llama_seq_id
>(chunk.indices[0]);
137 static_cast<int32_t
>(prompts[gi].size()),
138 0, batch_limit, seq) != 0)
139 throw std::runtime_error(
"logits::process_chunks - decode::many failed");
141 std::memcpy(output[gi],
get(ctx, -1), n_vocab *
sizeof(
float));
146 std::vector<decode::ScatterItem> scatter_items(chunk.indices.size());
147 for (
size_t k = 0; k < chunk.indices.size(); ++k) {
148 int32_t gi = group_start + chunk.indices[k];
149 scatter_items[k].tokens = prompts[gi];
150 scatter_items[k].start_pos = 0;
151 scatter_items[k].seq_id =
static_cast<llama_seq_id
>(chunk.indices[k]);
152 scatter_items[k].output_logits =
true;
156 static_cast<int32_t
>(scatter_items.size()),
158 throw std::runtime_error(
"logits::process_chunks - decode::scatter failed");
162 for (
size_t k = 0; k < scatter_items.size(); ++k) {
163 int32_t gi = group_start + chunk.indices[k];
164 int32_t item_n =
static_cast<int32_t
>(scatter_items[k].tokens.size());
165 std::memcpy(output[gi],
get(ctx, cursor + item_n - 1), n_vocab *
sizeof(
float));
171 for (int32_t s = 0; s < group_size; ++s) {