// NOTE(review): this listing looks like a lossy extraction. Original source
// line numbers (118, 119, 122, ...) are fused into the text, and many
// intermediate lines are missing, so it does not compile as shown. The
// comments below describe only the visible fragments.
//
// Capability probing: small dummy conversations are rendered through
// try_raw_render, and the output is searched for known needle strings to
// infer what the template supports.
//
// `contains`: true when `needle` occurs anywhere in `haystack`.
118 auto contains = [](
const std::string &haystack,
const std::string &needle) {
119 return haystack.find(needle) != std::string::npos;
// Marker strings whose presence in rendered output shows that user or system
// content was emitted by the template.
122 const std::string user_needle =
"<User Needle>";
123 const std::string sys_needle =
"<System Needle>";
// The same user message in two shapes: plain-string content, and an array
// of typed {"type": "text", "text": ...} parts.
124 const json dummy_str_user_msg = {{
"role",
"user"},
125 {
"content", user_needle}};
126 const json dummy_typed_user_msg = {
128 {
"content", json::array({{{
"type",
"text"}, {
"text", user_needle}}})}};
// Visible fragment of a check that the string-content render lacks the
// needle while the typed-content render contains it. The needle arguments
// are on missing lines (presumably user_needle). The result selects which
// message shape is used. Presumably it is assigned to dummy_user_msg, which
// is used below; the assignment line is not visible.
131 !contains(try_raw_render(json::array({dummy_str_user_msg}), {},
false),
133 contains(try_raw_render(json::array({dummy_typed_user_msg}), {},
false),
137 ? dummy_typed_user_msg
138 : dummy_str_user_msg;
// Visible fragment of the system-message probe. One branch of a conditional
// (whose lines are missing) uses a typed text-part array carrying sys_needle.
139 const json needle_system_msg = {
143 ? json::array({{{
"type",
"text"}, {
"text", sys_needle}}})
// Render with a dummy tool definition: "some_tool" with one required
// argument "arg". The surrounding call is only partially visible.
155 auto out = try_raw_render(
156 json::array({dummy_user_msg}),
159 {
"name",
"some_tool"},
160 {
"type",
"function"},
163 {
"name",
"some_tool"},
164 {
"description",
"Some tool."},
173 {
"description",
"Some argument."},
176 {
"required", json::array({
"arg"})},
// Probe how assistant `content` is handled. The conversation alternates
// user/assistant turns. It is rendered once with content "" and once with
// null, and the outputs are compared for user_needle.
184 const auto render_with_content = [&](
const json &content) {
185 const json assistant_msg{{
"role",
"assistant"}, {
"content", content}};
189 return try_raw_render(json::array({dummy_user_msg, assistant_msg,
190 dummy_user_msg, assistant_msg}),
193 auto out_empty = render_with_content(
"");
194 auto out_null = render_with_content(
json());
// Visible right-hand side: true when the "" render contains the user needle
// but the null render does not. The assigned variable is not visible.
196 contains(out_empty, user_needle) && !contains(out_null, user_needle);
// Helpers: one builds an assistant message carrying `tool_calls`, the other
// builds a single "function"-typed tool call with the given arguments.
199 auto make_tool_calls_msg = [&](
const json &tool_calls) {
201 {
"role",
"assistant"},
203 {
"tool_calls", tool_calls},
206 auto make_tool_call = [](
const std::string &tool_name,
207 const json &arguments) {
210 {
"type",
"function"},
213 {
"arguments", arguments},
218 const json dummy_args_obj{{
"argument_needle",
"print('Hello, World!')"}};
// Render a tool call whose arguments are a JSON *string* (via dump()), then
// check whether the argument key appears in one of three recognized syntaxes.
222 out = try_raw_render(json::array({
224 make_tool_calls_msg(json::array({make_tool_call(
225 "ipython", dummy_args_obj.dump())})),
228 auto tool_call_renders_str_arguments =
229 contains(out,
"<parameter=argument_needle>") ||
230 contains(out,
"\"argument_needle\":") ||
231 contains(out,
"'argument_needle':");
// The same probe with the arguments passed as a JSON *object*.
232 out = try_raw_render(json::array({
234 make_tool_calls_msg(json::array(
235 {make_tool_call(
"ipython", dummy_args_obj)})),
238 auto tool_call_renders_obj_arguments =
239 contains(out,
"<parameter=argument_needle>") ||
240 contains(out,
"\"argument_needle\":") ||
241 contains(out,
"'argument_needle':");
// Visible right-hand sides of two results (their variables are not visible):
// - "tool calls render at all": either argument form works.
// - "object arguments are required": only the object form works.
244 tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
246 !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
// Visible tail of a conditional choosing the dummy arguments. Its
// else-branch uses the string-encoded form.
251 :
json(dummy_args_obj.dump());
// Probe parallel tool calls: the output is checked for both tool names.
252 auto tc1 = make_tool_call(
"test_tool1", dummy_args);
253 auto tc2 = make_tool_call(
"test_tool2", dummy_args);
255 try_raw_render(json::array({
257 make_tool_calls_msg(json::array({tc1, tc2})),
261 contains(out,
"test_tool1") && contains(out,
"test_tool2");
// Probe tool responses: a tool call followed by a message carrying
// name / content / tool_call_id.
263 out = try_raw_render(json::array({dummy_user_msg,
264 make_tool_calls_msg(json::array({tc1})),
267 {
"name",
"test_tool1"},
268 {
"content",
"Some response!"},
269 {
"tool_call_id",
"call_911_"},
// Tool-call example inference: render the prompt without (prefix) and with
// (full) an assistant tool-call message, then keep the part that differs.
283 {
"arg1",
"some_value"},
285 const json tool_call_msg{
286 {
"role",
"assistant"},
294 {
"type",
"function"},
297 {
"name",
"tool_name"},
306 std::string prefix, full;
309 inputs.
messages = json::array({user_msg});
311 prefix =
apply(inputs);
315 inputs.
messages = json::array({user_msg, tool_call_msg});
317 full =
apply(inputs);
// Strip a trailing EOS token from `full`, optionally followed by one newline.
// NOTE(review): the first test compares a position in `full` against
// prefix.size() - eos_token_.size(), whereas the second test uses
// full.size(). This asymmetry looks suspicious; confirm it is intended.
// Two further hazards:
// - full[full.size() - 1] indexes out of range if `full` is empty.
// - The size_t subtractions wrap if eos_token_ is longer than the string.
319 auto eos_pos_last = full.rfind(eos_token_);
320 if (eos_pos_last == prefix.size() - eos_token_.size() ||
321 (full[full.size() - 1] ==
'\n' &&
322 (eos_pos_last == full.size() - eos_token_.size() - 1))) {
323 full = full.substr(0, eos_pos_last);
// Compute the length of the shared leading run of `prefix` and `full`. The
// loop body is only partially visible; it appears to special-case a
// mismatch at a '<' character.
325 size_t common_prefix_length = 0;
326 for (
size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
327 if (prefix[i] != full[i]) {
330 if (prefix[i] ==
'<') {
337 common_prefix_length = i + 1;
// Whatever follows the common prefix is the candidate tool-call example.
// If neither "tool_name" nor "some_value" appears in it, a failure message
// is reported. Otherwise (presumably; the else line is not visible) the
// candidate is stored in tool_call_example_.
339 auto example = full.substr(common_prefix_length);
340 if (example.find(
"tool_name") == std::string::npos &&
341 example.find(
"some_value") == std::string::npos) {
344 "Failed to infer a tool call example (possible template bug)\n");
346 tool_call_example_ = example;
// In the visible lines, any std::exception raised during example generation
// is reported to stderr and not rethrown.
349 }
catch (
const std::exception &e) {
350 fprintf(stderr,
"Failed to generate tool call example: %s\n", e.what());
// NOTE(review): fragment of the message-rendering routine. Many original
// lines are missing (the embedded line numbers jump), so only the visible
// logic is described here.
//
// First, scan the input messages to see which features they use.
383 json actual_messages;
385 auto has_tools = inputs.
tools.is_array() && !inputs.
tools.empty();
386 auto has_tool_calls =
false;
387 auto has_tool_responses =
false;
388 auto has_string_content =
false;
389 for (
const auto &message : inputs.
messages) {
390 if (message.contains(
"tool_calls") && !message[
"tool_calls"].is_null()) {
391 has_tool_calls =
true;
393 if (message.contains(
"role") && message[
"role"] ==
"tool") {
394 has_tool_responses =
true;
396 if (message.contains(
"content") && message[
"content"].is_string()) {
397 has_string_content =
true;
// Decide which polyfills apply. Where the right-hand side is visible, each
// combines an opts flag with a feature detected above. The remaining
// conditions are on missing lines.
401 auto polyfill_system_role =
403 auto polyfill_tools =
405 auto polyfill_tool_call_example =
406 polyfill_tools && opts.polyfill_tool_call_examples;
407 auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls &&
409 auto polyfill_tool_responses = opts.polyfill_tool_responses &&
410 has_tool_responses &&
412 auto polyfill_object_arguments = opts.polyfill_object_arguments &&
415 auto polyfill_typed_content = opts.polyfill_typed_content &&
416 has_string_content &&
// Rewriting happens only when polyfills are enabled and at least one applies.
// The leading `false ||` is a no-op seed for the disjunction.
419 auto needs_polyfills =
420 opts.apply_polyfills &&
421 (
false || polyfill_system_role || polyfill_tools ||
422 polyfill_tool_calls || polyfill_tool_responses ||
423 polyfill_object_arguments || polyfill_typed_content);
425 if (needs_polyfills) {
426 actual_messages = json::array();
// add_message: when polyfill_typed_content is on and the message has
// non-null string content, a rewritten message is appended. It keeps the
// role and carries the content as "text"; the rest of its shape is on
// missing lines. Otherwise (presumably) the message is appended unchanged.
428 auto add_message = [&](
const json &msg) {
429 if (polyfill_typed_content && msg.contains(
"content") &&
430 !msg.at(
"content").is_null() && msg.at(
"content").is_string()) {
431 actual_messages.push_back({
432 {
"role", msg.at(
"role")},
436 {
"text", msg.at(
"content")},
440 actual_messages.push_back(msg);
// System text waiting to be merged into a later message.
// flush_sys: if any text is pending, it is emitted with content
// pending_system (the emitting call is not visible), then cleared.
444 std::string pending_system;
445 auto flush_sys = [&]() {
446 if (!pending_system.empty()) {
449 {
"content", pending_system},
451 pending_system.clear();
// When polyfill_tools is set, the visible fragment builds a prompt text
// describing the tools. It appends the inferred tool-call example when that
// polyfill is on and an example is available. Otherwise (presumably) the
// input messages are used unchanged.
455 json adjusted_messages;
456 if (polyfill_tools) {
459 "You can call any of the following tools to satisfy the user's "
462 (!polyfill_tool_call_example || tool_call_example_.empty()
464 :
"\n\nExample tool call syntax:\n\n" +
465 tool_call_example_ +
"\n\n"));
467 adjusted_messages = inputs.
messages;
// Rewrite each message on a local copy. A message must have "role" and
// either "content" or "tool_calls"; otherwise std::runtime_error is thrown.
470 for (
const auto &message_ : adjusted_messages) {
471 auto message = message_;
472 if (!message.contains(
"role") ||
473 (!message.contains(
"content") && !message.contains(
"tool_calls"))) {
474 throw std::runtime_error(
"message must have 'role' and one of "
475 "'content' or 'tool_calls' fields: " +
478 std::string role = message.at(
"role");
// For "function" tool calls, string-encoded arguments are parsed into JSON
// in place. A parse failure is reported to stderr.
480 if (message.contains(
"tool_calls")) {
481 if (polyfill_object_arguments || polyfill_tool_calls) {
482 for (
auto &tool_call : message.at(
"tool_calls")) {
483 if (tool_call[
"type"] ==
"function") {
484 auto &function = tool_call.at(
"function");
485 auto &arguments = function.at(
"arguments");
486 if (arguments.is_string()) {
488 arguments = json::parse(arguments.get<std::string>());
489 }
catch (
const std::exception &ecvt) {
490 fprintf(stderr,
"Failed to parse arguments: %s\n",
// Tool-call polyfill: the tool calls are re-encoded as JSON text in
// "content", and the "tool_calls" field is removed. Each call keeps name,
// arguments and (if present) id. Original content is kept only when it is
// non-null and non-empty. Handling of non-"function" calls is on missing
// lines.
497 if (polyfill_tool_calls) {
498 auto tool_calls = json::array();
499 for (
const auto &tool_call : message.at(
"tool_calls")) {
500 if (tool_call.at(
"type") !=
"function") {
503 const auto &function = tool_call.at(
"function");
505 {
"name", function.at(
"name")},
506 {
"arguments", function.at(
"arguments")},
508 if (tool_call.contains(
"id")) {
509 tc[
"id"] = tool_call[
"id"];
511 tool_calls.push_back(tc);
514 {
"tool_calls", tool_calls},
516 if (message.contains(
"content")) {
517 auto content = message.at(
"content");
518 if (!content.is_null() && !content.empty()) {
519 obj[
"content"] = content;
522 message[
"content"] = obj.dump(2);
523 message.erase(
"tool_calls");
// Tool-response polyfill: a "tool" message becomes a "user" message. Its
// content is the JSON text of {"tool_response": {tool, content,
// tool_call_id}}, and the "name" field is removed.
526 if (polyfill_tool_responses && role ==
"tool") {
527 message[
"role"] =
"user";
529 {
"tool_response", json::object()},
531 if (message.contains(
"name")) {
532 obj[
"tool_response"][
"tool"] = message.at(
"name");
534 obj[
"tool_response"][
"content"] = message.at(
"content");
535 if (message.contains(
"tool_call_id")) {
536 obj[
"tool_response"][
"tool_call_id"] = message.at(
"tool_call_id");
538 message[
"content"] = obj.dump(2);
539 message.erase(
"name");
// System-role polyfill: system content is accumulated in pending_system,
// joined with newlines. It is prepended to the next user message's content.
// NOTE(review): assuming `json` is nlohmann::json, message["content"] on
// this non-const copy inserts a null "content" key when it is absent (e.g.
// tool_calls-only messages). Also, the std::string conversion below throws
// if content is not a string (e.g. a typed-part array). Confirm intended.
542 if (!message[
"content"].is_null() && polyfill_system_role) {
543 std::string content = message.at(
"content");
544 if (role ==
"system") {
545 if (!pending_system.empty())
546 pending_system +=
"\n";
547 pending_system += content;
550 if (role ==
"user") {
551 if (!pending_system.empty()) {
553 pending_system + (content.empty() ?
"" :
"\n" + content);
554 pending_system.clear();
561 add_message(message);
// Render context: the (possibly rewritten) messages, plus bos/eos tokens
// exposed as empty strings unless their opts flags are set.
569 {
"messages", actual_messages},
572 context->set(
"bos_token", opts.use_bos_token ? bos_token_ :
"");
573 context->set(
"eos_token", opts.use_eos_token ? eos_token_ :
"");
// Optional strftime_now(format) helper: takes exactly one positional
// argument and formats inputs.now in local time via std::put_time.
// NOTE(review): std::localtime returns a pointer to shared static storage
// (not thread-safe) and may return nullptr; it is dereferenced unchecked.
574 if (opts.define_strftime_now) {
575 auto now = inputs.
now;
580 args.
expectArgs(
"strftime_now", {1, 1}, {0, 0});
581 auto format = args.
args[0].get<std::string>();
583 auto time = std::chrono::system_clock::to_time_t(now);
584 auto local_time = *std::localtime(&time);
585 std::ostringstream ss;
586 ss << std::put_time(&local_time, format.c_str());
590 if (!inputs.
tools.is_null()) {
599 auto ret = template_root_->render(context);
// NOTE(review): fragment of a helper whose leading signature lines are not
// visible (embedded original line numbers jump).
// It returns a copy of `messages` with `system_prompt` merged in:
// - If the first message has role "system", its content becomes
//   existing + blank line + system_prompt.
// - Otherwise a new message with content system_prompt is inserted at the
//   front. Its role field is on a missing line, presumably "system".
// NOTE(review): assuming `json` is nlohmann::json, there are two throw
// cases. .at("role") throws if the first message has no "role". The
// std::string conversion throws if the existing system content is not a
// string.
608 const std::string &system_prompt) {
609 json messages_with_system = messages;
611 if (!messages_with_system.empty() &&
612 messages_with_system[0].at(
"role") ==
"system") {
613 std::string existing_system = messages_with_system.at(0).at(
"content");
614 messages_with_system[0] =
json{
616 {
"content", existing_system +
"\n\n" + system_prompt},
619 messages_with_system.insert(messages_with_system.begin(),
622 {
"content", system_prompt},
625 return messages_with_system;