liblloyal 1.0.0
Composable primitives for llama.cpp inference
Loading...
Searching...
No Matches
chat-template.hpp
Go to the documentation of this file.
1/*
2 Copyright 2024 Google LLC
3
4 Use of this source code is governed by an MIT-style
5 license that can be found in the LICENSE file or at
6 https://opensource.org/licenses/MIT.
7*/
8// SPDX-License-Identifier: MIT
9#pragma once
10
11#include "minja.hpp"
12
13#include <chrono>
14#include <cstddef>
15#include <cstdio>
16#include <ctime>
17#include <exception>
18#include <iomanip>
19#include <memory>
20#include <sstream>
21#include <stdexcept>
22#include <string>
23#include <vector>
24
25#include <lloyal/nlohmann/json.hpp>
26
27using json = nlohmann::ordered_json;
28
29namespace minja {
30
// Capabilities probed from a chat template at construction time.
// `supports_*` flags record features the template renders natively;
// `requires_*` flags record input-shape quirks the template demands.
struct chat_template_caps {
  // Template renders the `tools` definitions array.
  bool supports_tools = false;
  // Template renders assistant `tool_calls`.
  bool supports_tool_calls = false;
  // Template renders `role == "tool"` response messages.
  bool supports_tool_responses = false;
  // Template renders a leading `system` message.
  bool supports_system_role = false;
  // Template renders multiple tool calls in one assistant message.
  bool supports_parallel_tool_calls = false;
  // Template renders the `tool_call_id` field of tool responses.
  bool supports_tool_call_id = false;
  // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
  // Most other templates (and OpenAI's API) expect the arguments object to be
  // stringified.
  bool requires_object_arguments = false;
  // CohereForAI/c4ai-command-r-plus simple variant
  bool requires_non_null_content = false;
  // MiniMaxAI/MiniMax-Text-01 special
  bool requires_typed_content = false;
};
47
49 nlohmann::ordered_json messages;
50 nlohmann::ordered_json tools;
52 nlohmann::ordered_json extra_context;
53 std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
54};
55
70
72
73private:
75 std::string source_;
76 std::string bos_token_;
77 std::string eos_token_;
78 std::shared_ptr<minja::TemplateNode> template_root_;
79 std::string tool_call_example_;
80
81 std::string try_raw_render(const nlohmann::ordered_json &messages,
82 const nlohmann::ordered_json &tools,
83 bool add_generation_prompt,
84 const nlohmann::ordered_json &extra_context =
85 nlohmann::ordered_json()) const {
86 try {
88 inputs.messages = messages;
89 inputs.tools = tools;
90 inputs.add_generation_prompt = add_generation_prompt;
91 inputs.extra_context = extra_context;
92 // Use fixed date for tests
93 inputs.now = std::chrono::system_clock::from_time_t(0);
94
96 opts.apply_polyfills = false;
97
98 auto prompt = apply(inputs, opts);
99 // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
100 return prompt;
101 } catch (const std::exception &e) {
102 // fprintf(stderr, "try_raw_render error: %s\n", e.what());
103 return "";
104 }
105 }
106
107public:
108 chat_template(const std::string &source, const std::string &bos_token,
109 const std::string &eos_token)
110 : source_(source), bos_token_(bos_token), eos_token_(eos_token) {
111 template_root_ =
112 minja::Parser::parse(source_, {
113 /* .trim_blocks = */ true,
114 /* .lstrip_blocks = */ true,
115 /* .keep_trailing_newline = */ false,
116 });
117
118 auto contains = [](const std::string &haystack, const std::string &needle) {
119 return haystack.find(needle) != std::string::npos;
120 };
121
122 const std::string user_needle = "<User Needle>";
123 const std::string sys_needle = "<System Needle>";
124 const json dummy_str_user_msg = {{"role", "user"},
125 {"content", user_needle}};
126 const json dummy_typed_user_msg = {
127 {"role", "user"},
128 {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
129
131 !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false),
132 user_needle) &&
133 contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false),
134 user_needle);
135
136 const auto dummy_user_msg = caps_.requires_typed_content
137 ? dummy_typed_user_msg
138 : dummy_str_user_msg;
139 const json needle_system_msg = {
140 {"role", "system"},
141 {"content",
143 ? json::array({{{"type", "text"}, {"text", sys_needle}}})
144 : json(sys_needle)},
145 };
146
147 caps_.supports_system_role = contains(try_raw_render(
148 {
149 needle_system_msg,
150 dummy_user_msg,
151 },
152 {}, false),
153 sys_needle);
154
155 auto out = try_raw_render(
156 json::array({dummy_user_msg}),
157 json::array({
158 {
159 {"name", "some_tool"},
160 {"type", "function"},
161 {"function",
162 {
163 {"name", "some_tool"},
164 {"description", "Some tool."},
165 {"parameters",
166 {
167 {"type", "object"},
168 {"properties",
169 {
170 {"arg",
171 {
172 {"type", "string"},
173 {"description", "Some argument."},
174 }},
175 }},
176 {"required", json::array({"arg"})},
177 }},
178 }},
179 },
180 }),
181 false);
182 caps_.supports_tools = contains(out, "some_tool");
183
184 const auto render_with_content = [&](const json &content) {
185 const json assistant_msg{{"role", "assistant"}, {"content", content}};
186 // Render two assistant messages as some templates like QwQ-32B are
187 // handling the content differently depending on whether it's the last
188 // message or not (to remove the <think> tag in all but the last message).
189 return try_raw_render(json::array({dummy_user_msg, assistant_msg,
190 dummy_user_msg, assistant_msg}),
191 {}, false);
192 };
193 auto out_empty = render_with_content("");
194 auto out_null = render_with_content(json());
196 contains(out_empty, user_needle) && !contains(out_null, user_needle);
197
198 json j_null;
199 auto make_tool_calls_msg = [&](const json &tool_calls) {
200 return json{
201 {"role", "assistant"},
202 {"content", caps_.requires_non_null_content ? "" : j_null},
203 {"tool_calls", tool_calls},
204 };
205 };
206 auto make_tool_call = [](const std::string &tool_name,
207 const json &arguments) {
208 return json{
209 {"id", "call_1___"},
210 {"type", "function"},
211 {"function",
212 {
213 {"arguments", arguments},
214 {"name", tool_name},
215 }},
216 };
217 };
218 const json dummy_args_obj{{"argument_needle", "print('Hello, World!')"}};
219
220 // Note: the arguments are rendered in both cases, but may be
221 // double-escaped, which we don't want.
222 out = try_raw_render(json::array({
223 dummy_user_msg,
224 make_tool_calls_msg(json::array({make_tool_call(
225 "ipython", dummy_args_obj.dump())})),
226 }),
227 {}, false);
228 auto tool_call_renders_str_arguments =
229 contains(out, "<parameter=argument_needle>") ||
230 contains(out, "\"argument_needle\":") ||
231 contains(out, "'argument_needle':");
232 out = try_raw_render(json::array({
233 dummy_user_msg,
234 make_tool_calls_msg(json::array(
235 {make_tool_call("ipython", dummy_args_obj)})),
236 }),
237 {}, false);
238 auto tool_call_renders_obj_arguments =
239 contains(out, "<parameter=argument_needle>") ||
240 contains(out, "\"argument_needle\":") ||
241 contains(out, "'argument_needle':");
242
243 caps_.supports_tool_calls =
244 tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
246 !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
247
248 if (caps_.supports_tool_calls) {
249 auto dummy_args = caps_.requires_object_arguments
250 ? dummy_args_obj
251 : json(dummy_args_obj.dump());
252 auto tc1 = make_tool_call("test_tool1", dummy_args);
253 auto tc2 = make_tool_call("test_tool2", dummy_args);
254 auto out =
255 try_raw_render(json::array({
256 dummy_user_msg,
257 make_tool_calls_msg(json::array({tc1, tc2})),
258 }),
259 {}, false);
261 contains(out, "test_tool1") && contains(out, "test_tool2");
262
263 out = try_raw_render(json::array({dummy_user_msg,
264 make_tool_calls_msg(json::array({tc1})),
265 {
266 {"role", "tool"},
267 {"name", "test_tool1"},
268 {"content", "Some response!"},
269 {"tool_call_id", "call_911_"},
270 }}),
271 {}, false);
272 caps_.supports_tool_responses = contains(out, "Some response!");
273 caps_.supports_tool_call_id = contains(out, "call_911_");
274 }
275
276 try {
277 if (!caps_.supports_tools) {
278 const json user_msg{
279 {"role", "user"},
280 {"content", "Hey"},
281 };
282 const json args{
283 {"arg1", "some_value"},
284 };
285 const json tool_call_msg{
286 {"role", "assistant"},
287 {"content", caps_.requires_non_null_content ? "" : j_null},
288 {"tool_calls",
289 json::array({
290 {
291 // TODO: detect if requires numerical id or fixed length ==
292 // 6 like Nemo
293 {"id", "call_1___"},
294 {"type", "function"},
295 {"function",
296 {
297 {"name", "tool_name"},
298 {"arguments", (caps_.requires_object_arguments
299 ? args
300 : json(minja::Value(args).dump(
301 -1, /* to_json= */ true)))},
302 }},
303 },
304 })},
305 };
306 std::string prefix, full;
307 {
309 inputs.messages = json::array({user_msg});
310 inputs.add_generation_prompt = true;
311 prefix = apply(inputs);
312 }
313 {
315 inputs.messages = json::array({user_msg, tool_call_msg});
316 inputs.add_generation_prompt = false;
317 full = apply(inputs);
318 }
319 auto eos_pos_last = full.rfind(eos_token_);
320 if (eos_pos_last == prefix.size() - eos_token_.size() ||
321 (full[full.size() - 1] == '\n' &&
322 (eos_pos_last == full.size() - eos_token_.size() - 1))) {
323 full = full.substr(0, eos_pos_last);
324 }
325 size_t common_prefix_length = 0;
326 for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
327 if (prefix[i] != full[i]) {
328 break;
329 }
330 if (prefix[i] == '<') {
331 // DeepSeek R1's template (as of 20250209) adds a trailing <think>
332 // if add_generation_prompt, but it removes thinking tags for past
333 // messages. The prefix and full strings diverge at <think> vs.
334 // <|tool▁calls▁begin|>, we avoid consuming the leading <.
335 continue;
336 }
337 common_prefix_length = i + 1;
338 }
339 auto example = full.substr(common_prefix_length);
340 if (example.find("tool_name") == std::string::npos &&
341 example.find("some_value") == std::string::npos) {
342 fprintf(
343 stderr,
344 "Failed to infer a tool call example (possible template bug)\n");
345 } else {
346 tool_call_example_ = example;
347 }
348 }
349 } catch (const std::exception &e) {
350 fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
351 }
352 }
353
  // Raw template source exactly as passed to the constructor.
  const std::string &source() const { return source_; }
  // BOS/EOS token strings exposed to the template context (see
  // chat_template_options::use_bos_token / use_eos_token).
  const std::string &bos_token() const { return bos_token_; }
  const std::string &eos_token() const { return eos_token_; }
  // Capabilities probed at construction time (before any polyfills).
  const chat_template_caps &original_caps() const { return caps_; }
358
359 // Deprecated, please use the form with chat_template_inputs and
360 // chat_template_options
361 std::string
362 apply(const nlohmann::ordered_json &messages,
363 const nlohmann::ordered_json &tools, bool add_generation_prompt,
364 const nlohmann::ordered_json &extra_context = nlohmann::ordered_json(),
365 bool apply_polyfills = true) {
366 fprintf(stderr, "[%s] Deprecated!\n", __func__);
368 inputs.messages = messages;
369 inputs.tools = tools;
370 inputs.add_generation_prompt = add_generation_prompt;
371 inputs.extra_context = extra_context;
372 inputs.now = std::chrono::system_clock::now();
373
375 opts.apply_polyfills = apply_polyfills;
376
377 return apply(inputs, opts);
378 }
379
380 std::string
382 const chat_template_options &opts = chat_template_options()) const {
383 json actual_messages;
384
385 auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
386 auto has_tool_calls = false;
387 auto has_tool_responses = false;
388 auto has_string_content = false;
389 for (const auto &message : inputs.messages) {
390 if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
391 has_tool_calls = true;
392 }
393 if (message.contains("role") && message["role"] == "tool") {
394 has_tool_responses = true;
395 }
396 if (message.contains("content") && message["content"].is_string()) {
397 has_string_content = true;
398 }
399 }
400
401 auto polyfill_system_role =
402 opts.polyfill_system_role && !caps_.supports_system_role;
403 auto polyfill_tools =
404 opts.polyfill_tools && has_tools && !caps_.supports_tools;
405 auto polyfill_tool_call_example =
406 polyfill_tools && opts.polyfill_tool_call_examples;
407 auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls &&
408 !caps_.supports_tool_calls;
409 auto polyfill_tool_responses = opts.polyfill_tool_responses &&
410 has_tool_responses &&
412 auto polyfill_object_arguments = opts.polyfill_object_arguments &&
413 has_tool_calls &&
415 auto polyfill_typed_content = opts.polyfill_typed_content &&
416 has_string_content &&
418
419 auto needs_polyfills =
420 opts.apply_polyfills &&
421 (false || polyfill_system_role || polyfill_tools ||
422 polyfill_tool_calls || polyfill_tool_responses ||
423 polyfill_object_arguments || polyfill_typed_content);
424
425 if (needs_polyfills) {
426 actual_messages = json::array();
427
428 auto add_message = [&](const json &msg) {
429 if (polyfill_typed_content && msg.contains("content") &&
430 !msg.at("content").is_null() && msg.at("content").is_string()) {
431 actual_messages.push_back({
432 {"role", msg.at("role")},
433 {"content",
434 {{
435 {"type", "text"},
436 {"text", msg.at("content")},
437 }}},
438 });
439 } else {
440 actual_messages.push_back(msg);
441 }
442 };
443
444 std::string pending_system;
445 auto flush_sys = [&]() {
446 if (!pending_system.empty()) {
447 add_message({
448 {"role", "user"},
449 {"content", pending_system},
450 });
451 pending_system.clear();
452 }
453 };
454
455 json adjusted_messages;
456 if (polyfill_tools) {
457 adjusted_messages = add_system(
458 inputs.messages,
459 "You can call any of the following tools to satisfy the user's "
460 "requests: " +
461 minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
462 (!polyfill_tool_call_example || tool_call_example_.empty()
463 ? ""
464 : "\n\nExample tool call syntax:\n\n" +
465 tool_call_example_ + "\n\n"));
466 } else {
467 adjusted_messages = inputs.messages;
468 }
469
470 for (const auto &message_ : adjusted_messages) {
471 auto message = message_;
472 if (!message.contains("role") ||
473 (!message.contains("content") && !message.contains("tool_calls"))) {
474 throw std::runtime_error("message must have 'role' and one of "
475 "'content' or 'tool_calls' fields: " +
476 message.dump());
477 }
478 std::string role = message.at("role");
479
480 if (message.contains("tool_calls")) {
481 if (polyfill_object_arguments || polyfill_tool_calls) {
482 for (auto &tool_call : message.at("tool_calls")) {
483 if (tool_call["type"] == "function") {
484 auto &function = tool_call.at("function");
485 auto &arguments = function.at("arguments");
486 if (arguments.is_string()) {
487 try {
488 arguments = json::parse(arguments.get<std::string>());
489 } catch (const std::exception &ecvt) {
490 fprintf(stderr, "Failed to parse arguments: %s\n",
491 ecvt.what());
492 }
493 }
494 }
495 }
496 }
497 if (polyfill_tool_calls) {
498 auto tool_calls = json::array();
499 for (const auto &tool_call : message.at("tool_calls")) {
500 if (tool_call.at("type") != "function") {
501 continue;
502 }
503 const auto &function = tool_call.at("function");
504 auto tc = json{
505 {"name", function.at("name")},
506 {"arguments", function.at("arguments")},
507 };
508 if (tool_call.contains("id")) {
509 tc["id"] = tool_call["id"];
510 }
511 tool_calls.push_back(tc);
512 }
513 auto obj = json{
514 {"tool_calls", tool_calls},
515 };
516 if (message.contains("content")) {
517 auto content = message.at("content");
518 if (!content.is_null() && !content.empty()) {
519 obj["content"] = content;
520 }
521 }
522 message["content"] = obj.dump(2);
523 message.erase("tool_calls");
524 }
525 }
526 if (polyfill_tool_responses && role == "tool") {
527 message["role"] = "user";
528 auto obj = json{
529 {"tool_response", json::object()},
530 };
531 if (message.contains("name")) {
532 obj["tool_response"]["tool"] = message.at("name");
533 }
534 obj["tool_response"]["content"] = message.at("content");
535 if (message.contains("tool_call_id")) {
536 obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
537 }
538 message["content"] = obj.dump(2);
539 message.erase("name");
540 }
541
542 if (!message["content"].is_null() && polyfill_system_role) {
543 std::string content = message.at("content");
544 if (role == "system") {
545 if (!pending_system.empty())
546 pending_system += "\n";
547 pending_system += content;
548 continue;
549 } else {
550 if (role == "user") {
551 if (!pending_system.empty()) {
552 message["content"] =
553 pending_system + (content.empty() ? "" : "\n" + content);
554 pending_system.clear();
555 }
556 } else {
557 flush_sys();
558 }
559 }
560 }
561 add_message(message);
562 }
563 flush_sys();
564 } else {
565 actual_messages = inputs.messages;
566 }
567
568 auto context = minja::Context::make(json({
569 {"messages", actual_messages},
570 {"add_generation_prompt", inputs.add_generation_prompt},
571 }));
572 context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
573 context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
574 if (opts.define_strftime_now) {
575 auto now = inputs.now;
576 context->set(
577 "strftime_now",
578 Value::callable([now](const std::shared_ptr<minja::Context> &,
579 minja::ArgumentsValue &args) {
580 args.expectArgs("strftime_now", {1, 1}, {0, 0});
581 auto format = args.args[0].get<std::string>();
582
583 auto time = std::chrono::system_clock::to_time_t(now);
584 auto local_time = *std::localtime(&time);
585 std::ostringstream ss;
586 ss << std::put_time(&local_time, format.c_str());
587 return ss.str();
588 }));
589 }
590 if (!inputs.tools.is_null()) {
591 context->set("tools", minja::Value(inputs.tools));
592 }
593 if (!inputs.extra_context.is_null()) {
594 for (auto &kv : inputs.extra_context.items()) {
595 context->set(kv.key(), minja::Value(kv.value()));
596 }
597 }
598
599 auto ret = template_root_->render(context);
600 // fprintf(stderr, "actual_messages: %s\n",
601 // actual_messages.dump(2).c_str()); fprintf(stderr, "apply: %s\n\n",
602 // ret.c_str());
603 return ret;
604 }
605
606 static nlohmann::ordered_json
607 add_system(const nlohmann::ordered_json &messages,
608 const std::string &system_prompt) {
609 json messages_with_system = messages;
610
611 if (!messages_with_system.empty() &&
612 messages_with_system[0].at("role") == "system") {
613 std::string existing_system = messages_with_system.at(0).at("content");
614 messages_with_system[0] = json{
615 {"role", "system"},
616 {"content", existing_system + "\n\n" + system_prompt},
617 };
618 } else {
619 messages_with_system.insert(messages_with_system.begin(),
620 json{
621 {"role", "system"},
622 {"content", system_prompt},
623 });
624 }
625 return messages_with_system;
626 }
627};
628
629} // namespace minja
nlohmann::ordered_json json
static std::shared_ptr< Context > make(Value &&values, const std::shared_ptr< Context > &parent=builtins())
Definition minja.hpp:3005
static std::shared_ptr< TemplateNode > parse(const std::string &template_str, const Options &options)
Definition minja.hpp:2616
static Value callable(const CallableType &callable)
Definition minja.hpp:202
chat_template(const std::string &source, const std::string &bos_token, const std::string &eos_token)
const chat_template_caps & original_caps() const
const std::string & eos_token() const
std::string apply(const nlohmann::ordered_json &messages, const nlohmann::ordered_json &tools, bool add_generation_prompt, const nlohmann::ordered_json &extra_context=nlohmann::ordered_json(), bool apply_polyfills=true)
const std::string & source() const
const std::string & bos_token() const
static nlohmann::ordered_json add_system(const nlohmann::ordered_json &messages, const std::string &system_prompt)
std::string apply(const chat_template_inputs &inputs, const chat_template_options &opts=chat_template_options()) const
void expectArgs(const std::string &method_name, const std::pair< size_t, size_t > &pos_count, const std::pair< size_t, size_t > &kw_count)
Definition minja.hpp:534
std::vector< Value > args
Definition minja.hpp:513
std::chrono::system_clock::time_point now
nlohmann::ordered_json messages
nlohmann::ordered_json extra_context
nlohmann::ordered_json tools