liblloyal 1.0.0
Composable primitives for llama.cpp inference
Loading...
Searching...
No Matches
json-schema-to-grammar.hpp
Go to the documentation of this file.
1#pragma once
2
3// SPDX-License-Identifier: Apache-2.0
4// Copyright 2026 Lloyal Labs
5
6#include "common.hpp"
7#include "helpers.hpp" // For string_repeat, string_join, string_split
8#include <lloyal/nlohmann/json.hpp>
9
10#include <algorithm>
11#include <functional>
12#include <limits>
13#include <map>
14#include <regex>
15#include <sstream>
16#include <string>
17#include <unordered_map>
18#include <unordered_set>
19#include <vector>
20
33namespace lloyal {
34
35using json = nlohmann::ordered_json;
36
37// ===== PUBLIC API STRUCTS =====
38
40 std::function<std::string(const std::string &, const std::string &)> add_rule;
41 std::function<std::string(const std::string &, const json &)> add_schema;
42 std::function<void(json &)> resolve_refs;
43};
44
/// Options accepted by build_grammar().
struct common_grammar_options {
  /// Off by default. NOTE(review): presumably forwarded to the schema
  /// converter, where it decides whether a regex `.` also matches line
  /// breaks — confirm in build_grammar().
  bool dotall = false;
};
48
49// ===== PUBLIC API FUNCTIONS =====
50
// Declaration only; the definition is not part of this view.
// NOTE(review): by its name, this turns a JSON schema into grammar text, and
// `force_gbnf` (default false) selects GBNF output — confirm at the definition.
58std::string json_schema_to_grammar(const json &schema, bool force_gbnf = false);
59
// Declaration only; the definition is not part of this view.
// `cb` receives a common_grammar_builder; `options` defaults to a
// value-initialized common_grammar_options.
67std::string
68build_grammar(const std::function<void(const common_grammar_builder &)> &cb,
69 const common_grammar_options &options = {});
70
71} // namespace lloyal
72
73namespace lloyal::detail {
74
75// ===== CONSTANT TABLES =====
76
// Body of the "space" rule that the SchemaConverter constructor registers.
// Alternatives: empty, a single space, or one to two newlines followed by up
// to 20 spaces/tabs.
77inline constexpr const char *SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
78
/// A predefined grammar rule: its body plus the names of the other built-in
/// rules that body refers to.
struct BuiltinRule {
  /// Grammar text of the rule body.
  std::string content;
  /// Names of built-in rules referenced by `content`; they are looked up in
  /// PRIMITIVE_RULES / STRING_FORMAT_RULES when the rule is registered.
  std::vector<std::string> deps;
};
83
// Built-in rules for JSON primitives, keyed by rule name. Each entry pairs
// the rule body with the names of the built-in rules that body references
// (BuiltinRule::deps). Every name in this table is also treated as reserved
// by is_reserved_name().
91inline const std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
92 {"boolean", {"(\"true\" | \"false\") space", {}}},
93 {"decimal-part", {"[0-9]{1,16}", {}}},
94 {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
95 {"number",
96 {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? "
97 "integral-part)? space",
98 {"integral-part", "decimal-part"}}},
99 {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
100 {"value",
101 {"object | array | string | number | boolean | null",
102 {"object", "array", "string", "number", "boolean", "null"}}},
103 {"object",
104 {"\"{\" space ( string \":\" space value (\",\" space string \":\" space "
105 "value)* )? \"}\" space",
106 {"string", "value"}}},
107 {"array",
108 {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
109 {"uuid",
110 {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" "
111 "[0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space",
112 {}}},
// One string character: anything except quote, backslash, DEL and control
// characters, or a backslash escape (including \uXXXX).
113 {"char",
114 {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" "
115 "[0-9a-fA-F]{4})",
116 {}}},
117 {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
118 {"null", {"\"null\" space", {}}},
119};
120
// Built-in rules for date/time string formats, keyed by rule name. The bare
// "date", "time" and "date-time" rules describe the unquoted text; the
// "*-string" variants wrap them in double quotes followed by `space`.
// Every name in this table is also treated as reserved by is_reserved_name().
128inline const std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES =
129 {{"date",
130 {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | "
131 "[1-2] [0-9] | \"3\" [0-1] )",
132 {}}},
// Time of day with optional 3-digit fraction and a mandatory zone ("Z" or
// a +/-hh:mm offset).
133 {"time",
134 {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" "
135 "[0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) "
136 "\":\" [0-5] [0-9] )",
137 {}}},
138 {"date-time", {"date \"T\" time", {"date", "time"}}},
139 {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
140 {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
141 {"date-time-string",
142 {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}};
143
156inline bool is_reserved_name(const std::string &name) {
157 static std::unordered_set<std::string> RESERVED_NAMES;
158 if (RESERVED_NAMES.empty()) {
159 RESERVED_NAMES.insert("root");
160 for (const auto &p : PRIMITIVE_RULES)
161 RESERVED_NAMES.insert(p.first);
162 for (const auto &p : STRING_FORMAT_RULES)
163 RESERVED_NAMES.insert(p.first);
164 }
165 return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
166}
167
// Regex patterns for escaping. They are shared, header-level state, so they
// are const: nothing may reassign them after initialization.

// Runs of characters that are not allowed in a rule name.
inline const std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
// Characters that must be escaped inside a quoted grammar literal.
inline const std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
// Characters that must be escaped inside a grammar character range.
inline const std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");

// Replacement text for each escapable character.
inline const std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}};

// Regex metacharacters that end a run of literal text while scanning a
// pattern.
inline const std::unordered_set<char> NON_LITERAL_SET = {
    '|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
// Characters that need a backslash in a regex but stand for themselves inside
// a grammar literal.
inline const std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {
    '^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
180
181// ===== INTERNAL HELPER FUNCTIONS =====
182
/// Builds the grammar text that repeats `item_rule` between `min_items` and
/// `max_items` times, optionally interleaved with `separator_rule`.
///
/// @param item_rule      Grammar text of the repeated item.
/// @param min_items      Minimum number of repetitions.
/// @param max_items      Maximum number of repetitions;
///                       std::numeric_limits<int>::max() means "no upper bound".
/// @param separator_rule Grammar text placed between consecutive items; empty
///                       for plain repetition.
/// @return The repetition expression; an empty string when `max_items` is 0.
inline std::string build_repetition(const std::string &item_rule, int min_items,
                                    int max_items,
                                    const std::string &separator_rule = "") {
  const bool bounded = max_items != std::numeric_limits<int>::max();

  if (max_items == 0) {
    return "";
  }
  if (min_items == 0 && max_items == 1) {
    return item_rule + "?";
  }

  if (!separator_rule.empty()) {
    // item (sep item){min-1,max-1}; the whole thing is optional when zero
    // items are allowed.
    const std::string tail_item = "(" + separator_rule + " " + item_rule + ")";
    const int tail_min = min_items == 0 ? 0 : min_items - 1;
    const int tail_max = bounded ? max_items - 1 : max_items;
    std::string joined =
        item_rule + " " + build_repetition(tail_item, tail_min, tail_max);
    if (min_items == 0) {
      joined = "(" + joined + ")?";
    }
    return joined;
  }

  // No separator: use the shortest equivalent quantifier.
  if (!bounded && min_items == 1) {
    return item_rule + "+";
  }
  if (!bounded && min_items == 0) {
    return item_rule + "*";
  }
  std::string quantifier = "{" + std::to_string(min_items) + ",";
  if (bounded) {
    quantifier += std::to_string(max_items);
  }
  quantifier += "}";
  return item_rule + quantifier;
}
215
// Writes to `out` a grammar expression matching the decimal integers between
// `min_value` and `max_value`.
//
// std::numeric_limits<int>::min() / max() act as "no bound" sentinels for
// min_value / max_value respectively. `decimals_left` limits how many further
// digits the open-ended branches allow. `top_level` is false for recursive
// calls that emit the digits following an already-emitted leading digit; in
// that case leading '0' digits are permitted.
//
// Throws std::runtime_error when neither bound is set.
216inline void _build_min_max_int(int min_value, int max_value,
217 std::stringstream &out, int decimals_left = 16,
218 bool top_level = true) {
219 auto has_min = min_value != std::numeric_limits<int>::min();
220 auto has_max = max_value != std::numeric_limits<int>::max();
221
// Emits a one-character class: "[c]" or "[from-to]".
222 auto digit_range = [&](char from, char to) {
223 out << "[";
224 if (from == to) {
225 out << from;
226 } else {
227 out << from << "-" << to;
228 }
229 out << "]";
230 };
// Emits "[0-9]" followed by a {min}, {min,max} or {min,} quantifier; the
// quantifier is omitted when exactly one digit is requested.
231 auto more_digits = [&](int min_digits, int max_digits) {
232 out << "[0-9]";
233 if (min_digits == max_digits && min_digits == 1) {
234 return;
235 }
236 out << "{";
237 out << min_digits;
238 if (max_digits != min_digits) {
239 out << ",";
240 if (max_digits != std::numeric_limits<int>::max()) {
241 out << max_digits;
242 }
243 }
244 out << "}";
245 };
// Emits an expression for the digit strings between `from` and `to`: the
// common prefix becomes a quoted literal, then the first differing digit is
// split into lower edge / full middle block / upper edge alternatives.
// NOTE(review): the calls in this function pass strings of equal length.
246 std::function<void(const std::string_view &, const std::string_view &)>
247 uniform_range =
248 [&](const std::string_view &from, const std::string_view &to) {
249 size_t i = 0;
250 while (i < from.length() && i < to.length() && from[i] == to[i]) {
251 i++;
252 }
253 if (i > 0) {
254 out << "\"" << from.substr(0, i) << "\"";
255 }
256 if (i < from.length() && i < to.length()) {
257 if (i > 0) {
258 out << " ";
259 }
// Number of digits after the first differing one.
260 auto sub_len = from.length() - i - 1;
261 if (sub_len > 0) {
262 auto from_sub = from.substr(i + 1);
263 auto to_sub = to.substr(i + 1);
264 auto sub_zeros = lloyal::string_repeat("0", sub_len);
265 auto sub_nines = lloyal::string_repeat("9", sub_len);
266
// Set when the upper edge was already folded into the middle block.
267 auto to_reached = false;
268 out << "(";
269 if (from_sub == sub_zeros) {
270 digit_range(from[i], to[i] - 1);
271 out << " ";
272 more_digits(sub_len, sub_len);
273 } else {
274 out << "[" << from[i] << "] ";
275 out << "(";
276 uniform_range(from_sub, sub_nines);
277 out << ")";
278 if (from[i] < to[i] - 1) {
279 out << " | ";
280 if (to_sub == sub_nines) {
281 digit_range(from[i] + 1, to[i]);
282 to_reached = true;
283 } else {
284 digit_range(from[i] + 1, to[i] - 1);
285 }
286 out << " ";
287 more_digits(sub_len, sub_len);
288 }
289 }
290 if (!to_reached) {
291 out << " | ";
292 digit_range(to[i], to[i]);
293 out << " ";
294 uniform_range(sub_zeros, to_sub);
295 }
296 out << ")";
297 } else {
298 out << "[" << from[i] << "-" << to[i] << "]";
299 }
300 }
301 };
302
// Case 1: both bounds present.
303 if (has_min && has_max) {
// Entirely negative range: emit "-" and recurse on the negated bounds.
304 if (min_value < 0 && max_value < 0) {
305 out << "\"-\" (";
306 _build_min_max_int(-max_value, -min_value, out, decimals_left,
307 /* top_level= */ true);
308 out << ")";
309 return;
310 }
311
// Range spanning zero: emit the negative part, then continue from 0.
312 if (min_value < 0) {
313 out << "\"-\" (";
314 _build_min_max_int(0, -min_value, out, decimals_left,
315 /* top_level= */ true);
316 out << ") | ";
317 min_value = 0;
318 }
319
320 auto min_s = std::to_string(min_value);
321 auto max_s = std::to_string(max_value);
322 auto min_digits = min_s.length();
323 auto max_digits = max_s.length();
324
// One alternative per digit count below max_digits, then the final range.
325 for (auto digits = min_digits; digits < max_digits; digits++) {
326 uniform_range(min_s, lloyal::string_repeat("9", digits));
327 min_s = "1" + lloyal::string_repeat("0", digits);
328 out << " | ";
329 }
330 uniform_range(min_s, max_s);
331 return;
332 }
333
334 auto less_decimals = std::max(decimals_left - 1, 1);
335
// Case 2: only a lower bound.
336 if (has_min) {
337 if (min_value < 0) {
338 out << "\"-\" (";
339 _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out,
340 decimals_left, /* top_level= */ false);
341 out << ") | [0] | [1-9] ";
342 more_digits(0, decimals_left - 1);
343 } else if (min_value == 0) {
344 if (top_level) {
345 out << "[0] | [1-9] ";
346 more_digits(0, less_decimals);
347 } else {
348 more_digits(1, decimals_left);
349 }
350 } else if (min_value <= 9) {
351 char c = '0' + min_value;
// Nested calls may start with '0'; a top-level number may not.
352 auto range_start = top_level ? '1' : '0';
353 if (c > range_start) {
354 digit_range(range_start, c - 1);
355 out << " ";
356 more_digits(1, less_decimals);
357 out << " | ";
358 }
359 digit_range(c, '9');
360 out << " ";
361 more_digits(0, less_decimals);
362 } else {
363 auto min_s = std::to_string(min_value);
364 auto len = min_s.length();
365 auto c = min_s[0];
366
367 if (c > '1') {
368 digit_range(top_level ? '1' : '0', c - 1);
369 out << " ";
370 more_digits(len, less_decimals);
371 out << " | ";
372 }
373 digit_range(c, c);
374 out << " (";
// Same leading digit: recurse on the remaining digits of the lower bound.
375 _build_min_max_int(std::stoi(min_s.substr(1)),
376 std::numeric_limits<int>::max(), out, less_decimals,
377 /* top_level= */ false);
378 out << ")";
379 if (c < '9') {
380 out << " | ";
381 digit_range(c + 1, '9');
382 out << " ";
383 more_digits(len - 1, less_decimals);
384 }
385 }
386 return;
387 }
388
// Case 3: only an upper bound.
389 if (has_max) {
390 if (max_value >= 0) {
391 if (top_level) {
392 out << "\"-\" [1-9] ";
393 more_digits(0, less_decimals);
394 out << " | ";
395 }
396 _build_min_max_int(0, max_value, out, decimals_left,
397 /* top_level= */ true);
398 } else {
399 out << "\"-\" (";
400 _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out,
401 decimals_left, /* top_level= */ false);
402 out << ")";
403 }
404 return;
405 }
406
407 throw std::runtime_error(
408 "At least one of min_value or max_value must be set");
409}
410
/// Replaces every match of `regex` in `input` with the text produced by
/// `replacement`.
///
/// Matches are enumerated with std::sregex_iterator over the whole input, so
/// a pattern that can match the empty string still terminates (the previous
/// hand-rolled std::regex_search loop never advanced past a zero-length
/// match) and anchors are evaluated against the full input rather than
/// against each restart position.
///
/// @param input       Text to scan.
/// @param regex       Pattern to search for.
/// @param replacement Called once per match; its result is substituted for
///                    the matched text.
/// @return `input` with every match substituted.
inline std::string replacePattern(
    const std::string &input, const std::regex &regex,
    const std::function<std::string(const std::smatch &)> &replacement) {
  std::string result;
  // End of the previous match, i.e. start of the text not yet copied.
  std::string::const_iterator copied_up_to = input.cbegin();

  const std::sregex_iterator end;
  for (std::sregex_iterator it(input.cbegin(), input.cend(), regex); it != end;
       ++it) {
    const std::smatch &match = *it;
    result.append(copied_up_to, match[0].first);
    result.append(replacement(match));
    copied_up_to = match[0].second;
  }

  result.append(copied_up_to, input.cend());
  return result;
}
430
/// Wraps `literal` in double quotes for use as a grammar string literal,
/// escaping carriage return, newline and double quote.
///
/// Escapes exactly the characters matched by GRAMMAR_LITERAL_ESCAPE_RE with
/// the replacements listed in GRAMMAR_LITERAL_ESCAPES, but as a single linear
/// pass instead of a std::regex search plus std::function callback per match.
///
/// @param literal Raw text.
/// @return The quoted, escaped literal.
inline std::string format_literal(const std::string &literal) {
  std::string quoted;
  quoted.reserve(literal.size() + 2);
  quoted += '"';
  for (const char c : literal) {
    switch (c) {
    case '\r':
      quoted += "\\r";
      break;
    case '\n':
      quoted += "\\n";
      break;
    case '"':
      quoted += "\\\"";
      break;
    default:
      quoted += c;
      break;
    }
  }
  quoted += '"';
  return quoted;
}
439
440// Forward declare SchemaConverter for build_grammar
441class SchemaConverter;
442
443} // namespace lloyal::detail
444
445// Declare build_grammar here so SchemaConverter can be friended
446namespace lloyal {
447std::string
448build_grammar(const std::function<void(const common_grammar_builder &)> &cb,
449 const common_grammar_options &options);
450}
451
452namespace lloyal::detail {
453
454// ===== SCHEMA CONVERTER CLASS =====
455
457private:
458 friend std::string lloyal::build_grammar(
459 const std::function<void(const common_grammar_builder &)> &cb,
460 const common_grammar_options &options);
461
462 std::function<json(const std::string &)> _fetch_json;
463 bool _dotall;
464 std::map<std::string, std::string> _rules;
465 std::unordered_map<std::string, json> _refs;
466 std::unordered_set<std::string> _refs_being_resolved;
467 std::vector<std::string> _errors;
468 std::vector<std::string> _warnings;
469
470 std::string _add_rule(const std::string &name, const std::string &rule);
471 std::string _generate_union_rule(const std::string &name,
472 const std::vector<json> &alt_schemas);
473 std::string _visit_pattern(const std::string &pattern,
474 const std::string &name);
475 std::string _not_strings(const std::vector<std::string> &strings);
476 std::string _resolve_ref(const std::string &ref);
477 std::string _build_object_rule(
478 const std::vector<std::pair<std::string, json>> &properties,
479 const std::unordered_set<std::string> &required, const std::string &name,
480 const json &additional_properties);
481 std::string _add_primitive(const std::string &name, const BuiltinRule &rule);
482
483public:
485 const std::function<json(const std::string &)> &fetch_json, bool dotall)
486 : _fetch_json(fetch_json), _dotall(dotall) {
487 _rules["space"] = SPACE_RULE;
488 }
489
503 void resolve_refs(json &schema, const std::string &url);
504 std::string _generate_constant_rule(const json &value);
505
522 std::string visit(const json &schema, const std::string &name);
523 void check_errors();
524 std::string format_grammar();
525};
526
// The member functions are defined inline below. They follow the
// implementation pattern of json-schema-to-grammar.cpp.
530
531inline std::string SchemaConverter::_add_rule(const std::string &name,
532 const std::string &rule) {
533 std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
534 if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
535 _rules[esc_name] = rule;
536 return esc_name;
537 } else {
538 int i = 0;
539 while (_rules.find(esc_name + std::to_string(i)) != _rules.end() &&
540 _rules[esc_name + std::to_string(i)] != rule) {
541 i++;
542 }
543 std::string key = esc_name + std::to_string(i);
544 _rules[key] = rule;
545 return key;
546 }
547}
548
549inline std::string
550SchemaConverter::_generate_union_rule(const std::string &name,
551 const std::vector<json> &alt_schemas) {
552 std::vector<std::string> rules;
553 for (size_t i = 0; i < alt_schemas.size(); i++) {
554 rules.push_back(
555 visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") +
556 std::to_string(i)));
557 }
558 return lloyal::string_join(rules, " | ");
559}
560
// The methods below are ported from json-schema-to-grammar.cpp with the
// following conversions:
// 1. All static functions → inline functions in detail namespace
// 2. All member functions → inline member functions
// 3. string_repeat/join/split → lloyal::string_repeat/join/split
// 4. PRIMITIVE_RULES/STRING_FORMAT_RULES →
// detail::PRIMITIVE_RULES/STRING_FORMAT_RULES

// ===== METHOD IMPLEMENTATIONS =====
572
573inline std::string SchemaConverter::_visit_pattern(const std::string &pattern,
574 const std::string &name) {
575 if (!(pattern.front() == '^' && pattern.back() == '$')) {
576 _errors.push_back("Pattern must start with '^' and end with '$'");
577 return "";
578 }
579 std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
580 std::unordered_map<std::string, std::string> sub_rule_ids;
581
582 size_t i = 0;
583 size_t length = sub_pattern.length();
584
585 using literal_or_rule = std::pair<std::string, bool>;
586 auto to_rule = [&](const literal_or_rule &ls) {
587 auto is_literal = ls.second;
588 auto s = ls.first;
589 return is_literal ? "\"" + s + "\"" : s;
590 };
591 std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
592 size_t start = i;
593 std::vector<literal_or_rule> seq;
594
595 auto get_dot = [&]() {
596 std::string rule;
597 if (_dotall) {
598 rule = "[\\U00000000-\\U0010FFFF]";
599 } else {
600 rule = "[^\\x0A\\x0D]";
601 }
602 return _add_rule("dot", rule);
603 };
604
605 // Joins the sequence, merging consecutive literals together.
606 auto join_seq = [&]() {
607 std::vector<literal_or_rule> ret;
608
609 std::string literal;
610 auto flush_literal = [&]() {
611 if (literal.empty()) {
612 return false;
613 }
614 ret.emplace_back(literal, true);
615 literal.clear();
616 return true;
617 };
618
619 for (const auto &item : seq) {
620 auto is_literal = item.second;
621 if (is_literal) {
622 literal += item.first;
623 } else {
624 flush_literal();
625 ret.push_back(item);
626 }
627 }
628 flush_literal();
629
630 std::vector<std::string> results;
631 for (const auto &item : ret) {
632 results.push_back(to_rule(item));
633 }
634 return std::make_pair(lloyal::string_join(results, " "), false);
635 };
636
637 while (i < length) {
638 char c = sub_pattern[i];
639 if (c == '.') {
640 seq.emplace_back(get_dot(), false);
641 i++;
642 } else if (c == '(') {
643 i++;
644 if (i < length) {
645 if (sub_pattern[i] == '?') {
646 _warnings.push_back("Unsupported pattern syntax");
647 }
648 }
649 seq.emplace_back("(" + to_rule(transform()) + ")", false);
650 } else if (c == ')') {
651 i++;
652 if (start > 0 && sub_pattern[start - 1] != '(') {
653 _errors.push_back("Unbalanced parentheses");
654 }
655 return join_seq();
656 } else if (c == '[') {
657 std::string square_brackets = std::string(1, c);
658 i++;
659 while (i < length && sub_pattern[i] != ']') {
660 if (sub_pattern[i] == '\\') {
661 square_brackets += sub_pattern.substr(i, 2);
662 i += 2;
663 } else {
664 square_brackets += sub_pattern[i];
665 i++;
666 }
667 }
668 if (i >= length) {
669 _errors.push_back("Unbalanced square brackets");
670 }
671 square_brackets += ']';
672 i++;
673 seq.emplace_back(square_brackets, false);
674 } else if (c == '|') {
675 seq.emplace_back("|", false);
676 i++;
677 } else if (c == '*' || c == '+' || c == '?') {
678 seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
679 i++;
680 } else if (c == '{') {
681 std::string curly_brackets = std::string(1, c);
682 i++;
683 while (i < length && sub_pattern[i] != '}') {
684 curly_brackets += sub_pattern[i];
685 i++;
686 }
687 if (i >= length) {
688 _errors.push_back("Unbalanced curly brackets");
689 }
690 curly_brackets += '}';
691 i++;
692 auto nums = lloyal::string_split(
693 curly_brackets.substr(1, curly_brackets.length() - 2), ",");
694 int min_times = 0;
695 int max_times = std::numeric_limits<int>::max();
696 try {
697 if (nums.size() == 1) {
698 min_times = max_times = std::stoi(nums[0]);
699 } else if (nums.size() != 2) {
700 _errors.push_back("Wrong number of values in curly brackets");
701 } else {
702 if (!nums[0].empty()) {
703 min_times = std::stoi(nums[0]);
704 }
705 if (!nums[1].empty()) {
706 max_times = std::stoi(nums[1]);
707 }
708 }
709 } catch (const std::invalid_argument &e) {
710 _errors.push_back("Invalid number in curly brackets");
711 return std::make_pair("", false);
712 }
713 auto &last = seq.back();
714 auto &sub = last.first;
715 auto sub_is_literal = last.second;
716
717 if (!sub_is_literal) {
718 std::string &sub_id = sub_rule_ids[sub];
719 if (sub_id.empty()) {
720 sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()),
721 sub);
722 }
723 sub = sub_id;
724 }
725 seq.back().first = build_repetition(
726 sub_is_literal ? "\"" + sub + "\"" : sub, min_times, max_times, "");
727 seq.back().second = false;
728 } else {
729 std::string literal;
730 auto is_non_literal = [&](char c) {
731 return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
732 };
733 while (i < length) {
734 if (sub_pattern[i] == '\\' && i < length - 1) {
735 char next = sub_pattern[i + 1];
738 i++;
739 literal += sub_pattern[i];
740 i++;
741 } else {
742 literal += sub_pattern.substr(i, 2);
743 i += 2;
744 }
745 } else if (sub_pattern[i] == '"') {
746 literal += "\\\"";
747 i++;
748 } else if (!is_non_literal(sub_pattern[i]) &&
749 (i == length - 1 || literal.empty() ||
750 sub_pattern[i + 1] == '.' ||
751 !is_non_literal(sub_pattern[i + 1]))) {
752 literal += sub_pattern[i];
753 i++;
754 } else {
755 break;
756 }
757 }
758 if (!literal.empty()) {
759 seq.emplace_back(literal, true);
760 }
761 }
762 }
763 return join_seq();
764 };
765 return _add_rule(name,
766 "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
767}
768
// Builds the body of a rule for a quoted string. The given `strings` are
// loaded into a character trie; at each trie level the rule accepts either a
// listed character followed by that child's continuation, or any other
// character (except '"') followed by arbitrary `char`s. A complete listed
// string is only accepted when at least one more `char` follows it.
// NOTE(review): the net effect appears to be "any string other than the
// listed ones"; characters are written into [...] without escaping — confirm
// that callers never pass strings containing ']' , '\\' or '"'.
769inline std::string
770SchemaConverter::_not_strings(const std::vector<std::string> &strings) {
771 struct TrieNode {
772 std::map<char, TrieNode> children;
773 bool is_end_of_string;
774
775 TrieNode() : is_end_of_string(false) {}
776
777 void insert(const std::string &string) {
778 auto node = this;
779 for (char c : string) {
780 node = &node->children[c];
781 }
782 node->is_end_of_string = true;
783 }
784 };
785
786 TrieNode trie;
787 for (const auto &s : strings) {
788 trie.insert(s);
789 }
790
// Make sure the "char" primitive exists; its rule name is referenced below.
791 std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
792 std::ostringstream out;
793 out << "[\"] ( ";
794 std::function<void(const TrieNode &)> visit = [&](const TrieNode &node) {
// Collects this node's child characters for the trailing negated class.
795 std::ostringstream rejects;
796 auto first = true;
797 for (const auto &kv : node.children) {
798 rejects << kv.first;
799 if (first) {
800 first = false;
801 } else {
802 out << " | ";
803 }
804 out << "[" << kv.first << "]";
805 if (!kv.second.children.empty()) {
806 out << " (";
807 visit(kv.second);
808 out << ")";
809 } else if (kv.second.is_end_of_string) {
// Leaf of a listed string: require at least one extra character.
810 out << " " << char_rule << "+";
811 }
812 }
813 if (!node.children.empty()) {
814 if (!first) {
815 out << " | ";
816 }
// Any character that is not a listed child, then anything.
817 out << "[^\"" << rejects.str() << "] " << char_rule << "*";
818 }
819 };
820 visit(trie);
821
822 out << " )";
// The group is optional unless the empty string itself is listed.
823 if (!trie.is_end_of_string) {
824 out << "?";
825 }
826 out << " [\"] space";
827 return out.str();
828}
829
830inline std::string SchemaConverter::_resolve_ref(const std::string &ref) {
831 std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
832 if (_rules.find(ref_name) == _rules.end() &&
833 _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
834 _refs_being_resolved.insert(ref);
835 json resolved = _refs[ref];
836 ref_name = visit(resolved, ref_name);
837 _refs_being_resolved.erase(ref);
838 }
839 return ref_name;
840}
841
// Builds the rule body for a JSON object.
//
// properties:            (name, schema) pairs, in output order.
// required:              names of properties that must be present.
// name:                  prefix for the helper rules created here.
// additional_properties: `true` or an object schema allows extra keys;
//                        anything else adds no extra-key rule.
//
// Required properties are emitted first, in order. Optional properties
// follow in order, each one skippable; extra keys are represented by the
// pseudo-property "*", which may repeat.
842inline std::string SchemaConverter::_build_object_rule(
843 const std::vector<std::pair<std::string, json>> &properties,
844 const std::unordered_set<std::string> &required, const std::string &name,
845 const json &additional_properties) {
846 std::vector<std::string> required_props;
847 std::vector<std::string> optional_props;
// Property name -> name of its `"key" : value` rule.
848 std::unordered_map<std::string, std::string> prop_kv_rule_names;
849 std::vector<std::string> prop_names;
850 for (const auto &kv : properties) {
851 const auto &prop_name = kv.first;
852 const auto &prop_schema = kv.second;
853
854 std::string prop_rule_name =
855 visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
// The key literal is the JSON-encoded property name.
856 prop_kv_rule_names[prop_name] =
857 _add_rule(name + (name.empty() ? "" : "-") + prop_name + "-kv",
858 format_literal(json(prop_name).dump()) +
859 " space \":\" space " + prop_rule_name);
860 if (required.find(prop_name) != required.end()) {
861 required_props.push_back(prop_name);
862 } else {
863 optional_props.push_back(prop_name);
864 }
865 prop_names.push_back(prop_name);
866 }
867 if ((additional_properties.is_boolean() &&
868 additional_properties.get<bool>()) ||
869 additional_properties.is_object()) {
870 std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
871 std::string value_rule =
872 additional_properties.is_object()
873 ? visit(additional_properties, sub_name + "-value")
874 : _add_primitive("value", PRIMITIVE_RULES.at("value"));
875
// With declared properties, extra keys use the _not_strings() rule built
// from their names; otherwise any string is accepted as a key.
876 auto key_rule = prop_names.empty()
877 ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
878 : _add_rule(sub_name + "-k", _not_strings(prop_names));
879 std::string kv_rule =
880 _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
881 prop_kv_rule_names["*"] = kv_rule;
882 optional_props.push_back("*");
883 }
884
885 std::string rule = "\"{\" space ";
886 for (size_t i = 0; i < required_props.size(); i++) {
887 if (i > 0) {
888 rule += " \",\" space ";
889 }
890 rule += prop_kv_rule_names[required_props[i]];
891 }
892
893 if (!optional_props.empty()) {
894 rule += " (";
895 if (!required_props.empty()) {
896 rule += " \",\" space ( ";
897 }
898
// Emits ks[0] followed by a "-rest" helper rule covering ks[1..].
// first_is_optional: ks[0] is written as an optional, comma-prefixed item.
899 std::function<std::string(const std::vector<std::string> &, bool)>
900 get_recursive_refs = [&](const std::vector<std::string> &ks,
901 bool first_is_optional) {
902 std::string res;
903 if (ks.empty()) {
904 return res;
905 }
906 std::string k = ks[0];
907 std::string kv_rule_name = prop_kv_rule_names[k];
908 std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
909 if (first_is_optional) {
910 res = comma_ref + (k == "*" ? "*" : "?");
911 } else {
912 res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
913 }
914 if (ks.size() > 1) {
915 res += " " +
916 _add_rule(name + (name.empty() ? "" : "-") + k + "-rest",
917 get_recursive_refs(std::vector<std::string>(
918 ks.begin() + 1, ks.end()),
919 true));
920 }
921 return res;
922 };
923
// One alternative per possible first optional property that is present.
924 for (size_t i = 0; i < optional_props.size(); i++) {
925 if (i > 0) {
926 rule += " | ";
927 }
928 rule += get_recursive_refs(
929 std::vector<std::string>(optional_props.begin() + i,
930 optional_props.end()),
931 false);
932 }
933 if (!required_props.empty()) {
934 rule += " )";
935 }
936 rule += " )?";
937 }
938
939 rule += " \"}\" space";
940
941 return rule;
942}
943
944inline std::string SchemaConverter::_add_primitive(const std::string &name,
945 const BuiltinRule &rule) {
946 auto n = _add_rule(name, rule.content);
947 for (const auto &dep : rule.deps) {
948 BuiltinRule dep_rule;
949 auto it = PRIMITIVE_RULES.find(dep);
950 if (it == PRIMITIVE_RULES.end()) {
951 it = STRING_FORMAT_RULES.find(dep);
952 if (it == STRING_FORMAT_RULES.end()) {
953 _errors.push_back("Rule " + dep + " not known");
954 continue;
955 }
956 }
957 if (_rules.find(dep) == _rules.end()) {
958 _add_primitive(dep, it->second);
959 }
960 }
961 return n;
962}
963
965 const std::string &url) {
966 /*
967 * Resolves all $ref fields in the given schema, fetching any remote schemas,
968 * replacing each $ref with absolute reference URL and populates _refs with
969 * the respective referenced (sub)schema dictionaries.
970 */
971 std::function<void(json &)> visit_refs = [&](json &n) {
972 if (n.is_array()) {
973 for (auto &x : n) {
974 visit_refs(x);
975 }
976 } else if (n.is_object()) {
977 if (n.contains("$ref")) {
978 std::string ref = n["$ref"];
979 if (_refs.find(ref) == _refs.end()) {
980 json target;
981 if (ref.find("https://") == 0) {
982 std::string base_url = ref.substr(0, ref.find('#'));
983 auto it = _refs.find(base_url);
984 if (it != _refs.end()) {
985 target = it->second;
986 } else {
987 // Fetch the referenced schema and resolve its refs
988 auto referenced = _fetch_json(ref);
989 resolve_refs(referenced, base_url);
990 _refs[base_url] = referenced;
991 }
992 if (ref.find('#') == std::string::npos ||
993 ref.substr(ref.find('#') + 1).empty()) {
994 return;
995 }
996 } else if (ref.find("#/") == 0) {
997 target = schema;
998 n["$ref"] = url + ref;
999 ref = url + ref;
1000 } else {
1001 _errors.push_back("Unsupported ref: " + ref);
1002 return;
1003 }
1004 std::string pointer = ref.substr(ref.find('#') + 1);
1005 std::vector<std::string> tokens = lloyal::string_split(pointer, "/");
1006 for (size_t i = 1; i < tokens.size(); ++i) {
1007 std::string sel = tokens[i];
1008 if (target.is_null() || !target.contains(sel)) {
1009 _errors.push_back("Error resolving ref " + ref + ": " + sel +
1010 " not in " + target.dump());
1011 return;
1012 }
1013 target = target[sel];
1014 }
1015 _refs[ref] = target;
1016 }
1017 } else {
1018 for (auto &kv : n.items()) {
1019 visit_refs(kv.value());
1020 }
1021 }
1022 }
1023 };
1024
1025 visit_refs(schema);
1026}
1027
1028inline std::string SchemaConverter::_generate_constant_rule(const json &value) {
1029 return format_literal(value.dump());
1030}
1031
1032inline std::string SchemaConverter::visit(const json &schema,
1033 const std::string &name) {
1034 json schema_type = schema.contains("type") ? schema["type"] : json();
1035 std::string schema_format =
1036 schema.contains("format") ? schema["format"].get<std::string>() : "";
1037 std::string rule_name = is_reserved_name(name) ? name + "-"
1038 : name.empty() ? "root"
1039 : name;
1040
1041 if (schema.contains("$ref")) {
1042 return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
1043 } else if (schema.contains("oneOf") || schema.contains("anyOf")) {
1044 std::vector<json> alt_schemas =
1045 schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>()
1046 : schema["anyOf"].get<std::vector<json>>();
1047 return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
1048 } else if (schema_type.is_array()) {
1049 std::vector<json> schema_types;
1050 for (const auto &t : schema_type) {
1051 json schema_copy(schema);
1052 schema_copy["type"] = t;
1053 schema_types.push_back(schema_copy);
1054 }
1055 return _add_rule(rule_name, _generate_union_rule(name, schema_types));
1056 } else if (schema.contains("const")) {
1057 return _add_rule(rule_name,
1058 _generate_constant_rule(schema["const"]) + " space");
1059 } else if (schema.contains("enum")) {
1060 std::vector<std::string> enum_values;
1061 for (const auto &v : schema["enum"]) {
1062 enum_values.push_back(_generate_constant_rule(v));
1063 }
1064 return _add_rule(rule_name,
1065 "(" + lloyal::string_join(enum_values, " | ") + ") space");
1066 } else if ((schema_type.is_null() || schema_type == "object") &&
1067 (schema.contains("properties") ||
1068 (schema.contains("additionalProperties") &&
1069 schema["additionalProperties"] != true))) {
1070 std::unordered_set<std::string> required;
1071 if (schema.contains("required") && schema["required"].is_array()) {
1072 for (const auto &item : schema["required"]) {
1073 if (item.is_string()) {
1074 required.insert(item.get<std::string>());
1075 }
1076 }
1077 }
1078 std::vector<std::pair<std::string, json>> properties;
1079 if (schema.contains("properties")) {
1080 for (const auto &prop : schema["properties"].items()) {
1081 properties.emplace_back(prop.key(), prop.value());
1082 }
1083 }
1084 return _add_rule(rule_name,
1085 _build_object_rule(properties, required, name,
1086 schema.contains("additionalProperties")
1087 ? schema["additionalProperties"]
1088 : json()));
1089 } else if ((schema_type.is_null() || schema_type == "object" ||
1090 schema_type == "string") &&
1091 schema.contains("allOf")) {
1092 std::unordered_set<std::string> required;
1093 std::vector<std::pair<std::string, json>> properties;
1094 std::map<std::string, size_t> enum_values;
1095 std::string hybrid_name = name;
1096 std::function<void(const json &, bool)> add_component =
1097 [&](const json &comp_schema, bool is_required) {
1098 if (comp_schema.contains("$ref")) {
1099 add_component(_refs[comp_schema["$ref"]], is_required);
1100 } else if (comp_schema.contains("properties")) {
1101 for (const auto &prop : comp_schema["properties"].items()) {
1102 properties.emplace_back(prop.key(), prop.value());
1103 if (is_required) {
1104 required.insert(prop.key());
1105 }
1106 }
1107 } else if (comp_schema.contains("enum")) {
1108 for (const auto &v : comp_schema["enum"]) {
1109 const auto rule = _generate_constant_rule(v);
1110 if (enum_values.find(rule) == enum_values.end()) {
1111 enum_values[rule] = 0;
1112 }
1113 enum_values[rule] += 1;
1114 }
1115 } else {
1116 // todo warning
1117 }
1118 };
1119 for (auto &t : schema["allOf"]) {
1120 if (t.contains("anyOf")) {
1121 for (auto &tt : t["anyOf"]) {
1122 add_component(tt, false);
1123 }
1124 } else {
1125 add_component(t, true);
1126 }
1127 }
1128 if (!enum_values.empty()) {
1129 std::vector<std::string> enum_intersection;
1130 for (const auto &p : enum_values) {
1131 if (p.second == schema["allOf"].size()) {
1132 enum_intersection.push_back(p.first);
1133 }
1134 }
1135 if (!enum_intersection.empty()) {
1136 return _add_rule(rule_name,
1137 "(" + lloyal::string_join(enum_intersection, " | ") +
1138 ") space");
1139 }
1140 }
1141 return _add_rule(rule_name, _build_object_rule(properties, required,
1142 hybrid_name, json()));
1143 } else if ((schema_type.is_null() || schema_type == "array") &&
1144 (schema.contains("items") || schema.contains("prefixItems"))) {
1145 json items =
1146 schema.contains("items") ? schema["items"] : schema["prefixItems"];
1147 if (items.is_array()) {
1148 std::string rule = "\"[\" space ";
1149 for (size_t i = 0; i < items.size(); i++) {
1150 if (i > 0) {
1151 rule += " \",\" space ";
1152 }
1153 rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" +
1154 std::to_string(i));
1155 }
1156 rule += " \"]\" space";
1157 return _add_rule(rule_name, rule);
1158 } else {
1159 std::string item_rule_name =
1160 visit(items, name + (name.empty() ? "" : "-") + "item");
1161 int min_items =
1162 schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
1163 json max_items_json =
1164 schema.contains("maxItems") ? schema["maxItems"] : json();
1165 int max_items = max_items_json.is_number_integer()
1166 ? max_items_json.get<int>()
1167 : std::numeric_limits<int>::max();
1168
1169 return _add_rule(rule_name,
1170 "\"[\" space " +
1171 build_repetition(item_rule_name, min_items,
1172 max_items, "\",\" space") +
1173 " \"]\" space");
1174 }
1175 } else if ((schema_type.is_null() || schema_type == "string") &&
1176 schema.contains("pattern")) {
1177 return _visit_pattern(schema["pattern"], rule_name);
1178 } else if ((schema_type.is_null() || schema_type == "string") &&
1179 std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
1180 return _add_primitive(rule_name == "root" ? "root" : schema_format,
1181 PRIMITIVE_RULES.at("uuid"));
1182 } else if ((schema_type.is_null() || schema_type == "string") &&
1183 STRING_FORMAT_RULES.find(schema_format + "-string") !=
1184 STRING_FORMAT_RULES.end()) {
1185 auto prim_name = schema_format + "-string";
1186 return _add_rule(
1187 rule_name,
1188 _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
1189 } else if (schema_type == "string" &&
1190 (schema.contains("minLength") || schema.contains("maxLength"))) {
1191 std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
1192 int min_len =
1193 schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
1194 int max_len = schema.contains("maxLength")
1195 ? schema["maxLength"].get<int>()
1196 : std::numeric_limits<int>::max();
1197 return _add_rule(
1198 rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) +
1199 " \"\\\"\" space");
1200 } else if (schema_type == "integer" &&
1201 (schema.contains("minimum") ||
1202 schema.contains("exclusiveMinimum") ||
1203 schema.contains("maximum") ||
1204 schema.contains("exclusiveMaximum"))) {
1205 int min_value = std::numeric_limits<int>::min();
1206 int max_value = std::numeric_limits<int>::max();
1207 if (schema.contains("minimum")) {
1208 min_value = schema["minimum"].get<int>();
1209 } else if (schema.contains("exclusiveMinimum")) {
1210 min_value = schema["exclusiveMinimum"].get<int>() + 1;
1211 }
1212 if (schema.contains("maximum")) {
1213 max_value = schema["maximum"].get<int>();
1214 } else if (schema.contains("exclusiveMaximum")) {
1215 max_value = schema["exclusiveMaximum"].get<int>() - 1;
1216 }
1217 std::stringstream out;
1218 out << "(";
1219 _build_min_max_int(min_value, max_value, out);
1220 out << ") space";
1221 return _add_rule(rule_name, out.str());
1222 } else if (schema.empty() || schema_type == "object") {
1223 return _add_rule(rule_name,
1224 _add_primitive("object", PRIMITIVE_RULES.at("object")));
1225 } else {
1226 if (!schema_type.is_string() ||
1227 PRIMITIVE_RULES.find(schema_type.get<std::string>()) ==
1228 PRIMITIVE_RULES.end()) {
1229 _errors.push_back("Unrecognized schema: " + schema.dump());
1230 return "";
1231 }
1232 // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at
1233 // least for zero
1234 return _add_primitive(rule_name == "root" ? "root"
1235 : schema_type.get<std::string>(),
1236 PRIMITIVE_RULES.at(schema_type.get<std::string>()));
1237 }
1238}
1239
1241 if (!_errors.empty()) {
1242 throw std::runtime_error("JSON schema conversion failed:\n" +
1243 lloyal::string_join(_errors, "\n"));
1244 }
1245 if (!_warnings.empty()) {
1246 fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n",
1247 lloyal::string_join(_warnings, "; ").c_str());
1248 }
1249}
1250
1252 std::stringstream ss;
1253 for (const auto &kv : _rules) {
1254 ss << kv.first << " ::= " << kv.second << std::endl;
1255 }
1256 return ss.str();
1257}
1258
1259} // namespace lloyal::detail
1260
1261namespace lloyal {
1262
1263// ===== PUBLIC API IMPLEMENTATION =====
1264
1265inline std::string json_schema_to_grammar(const json &schema, bool force_gbnf) {
1266#ifdef LLAMA_USE_LLGUIDANCE
1267 if (!force_gbnf) {
1268 return "%llguidance {}\nstart: %json " + schema.dump();
1269 }
1270#else
1271 (void)force_gbnf;
1272#endif // LLAMA_USE_LLGUIDANCE
1273 return build_grammar([&](const common_grammar_builder &callbacks) {
1274 auto copy = schema;
1275 callbacks.resolve_refs(copy);
1276 callbacks.add_schema("", copy);
1277 });
1278}
1279
1280inline std::string
1281build_grammar(const std::function<void(const common_grammar_builder &)> &cb,
1282 const common_grammar_options &options) {
1283 detail::SchemaConverter converter([&](const std::string &) { return json(); },
1284 options.dotall);
1285 common_grammar_builder builder{
1286 /* .add_rule = */ [&](const std::string &name, const std::string &rule) {
1287 return converter._add_rule(name, rule);
1288 },
1289 /* .add_schema = */
1290 [&](const std::string &name, const nlohmann::ordered_json &schema) {
1291 return converter.visit(schema, name == "root" ? "" : name);
1292 },
1293 /* .resolve_refs = */
1294 [&](nlohmann::ordered_json &schema) {
1295 converter.resolve_refs(schema, "");
1296 }};
1297 cb(builder);
1298 converter.check_errors();
1299 return converter.format_grammar();
1300}
1301
1302} // namespace lloyal
nlohmann::ordered_json json
SchemaConverter(const std::function< json(const std::string &)> &fetch_json, bool dotall)
std::string visit(const json &schema, const std::string &name)
Convert schema node to GBNF rule.
std::string _generate_constant_rule(const json &value)
void resolve_refs(json &schema, const std::string &url)
Resolve $ref pointers in JSON schema.
Helper Utilities.
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]")
const std::unordered_map< std::string, BuiltinRule > STRING_FORMAT_RULES
Grammar rules for string format validation.
const std::unordered_set< char > NON_LITERAL_SET
std::string replacePattern(const std::string &input, const std::regex &regex, const std::function< std::string(const std::smatch &)> &replacement)
void _build_min_max_int(int min_value, int max_value, std::stringstream &out, int decimals_left=16, bool top_level=true)
bool is_reserved_name(const std::string &name)
Check if name conflicts with GBNF reserved keywords.
const std::unordered_set< char > ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS
std::string format_literal(const std::string &literal)
const std::unordered_map< std::string, BuiltinRule > PRIMITIVE_RULES
Built-in grammar rules for JSON primitives.
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+")
std::string build_repetition(const std::string &item_rule, int min_items, int max_items, const std::string &separator_rule="")
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]")
constexpr const char * SPACE_RULE
const std::unordered_map< char, std::string > GRAMMAR_LITERAL_ESCAPES
JSON Schema to Grammar Converter (Header-Only)
std::string string_repeat(const std::string &str, size_t n)
Definition helpers.hpp:426
std::string json_schema_to_grammar(const json &schema, bool force_gbnf=false)
Convert JSON schema to GBNF grammar.
std::string string_join(const std::vector< std::string > &values, const std::string &separator)
Definition helpers.hpp:442
std::string build_grammar(const std::function< void(const common_grammar_builder &)> &cb, const common_grammar_options &options={})
Build grammar from callback.
std::vector< std::string > string_split(const std::string &str, const std::string &delimiter)
Definition helpers.hpp:455
nlohmann::ordered_json json
Definition helpers.hpp:50
Definition minja.hpp:575
std::function< std::string(const std::string &, const std::string &)> add_rule
std::function< void(json &)> resolve_refs
std::function< std::string(const std::string &, const json &)> add_schema
std::vector< std::string > deps