Add llama grammar to ensure proper JSON output

This commit is contained in:
Aaron Po
2026-04-15 13:39:01 -04:00
parent ddf4bcb981
commit 62dfb5e14a
7 changed files with 115 additions and 231 deletions

View File

@@ -11,7 +11,6 @@
#include <boost/json.hpp>
#include <cctype>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
@@ -97,11 +96,16 @@ std::string ToChatPrompt(const llama_model* model,
return combined_prompt;
}
const std::array<llama_chat_message, 2> messages = {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
const std::array<llama_chat_message, 2> messages = {{
{.role = "system", .content = system_prompt.c_str()},
{.role = "user", .content = user_prompt.c_str()},
}};
constexpr std::size_t min_template_buffer_size = 1024;
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
min_template_buffer_size,
(system_prompt.size() + user_prompt.size()) * 4));
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
@@ -113,11 +117,11 @@ std::string ToChatPrompt(const llama_model* model,
return result;
}
if (result >= static_cast<int32_t>(buffer.size())) {
const auto buffer_size = static_cast<int32_t>(buffer.size());
if (result >= buffer_size) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
true, buffer.data(),
static_cast<int32_t>(buffer.size()));
true, buffer.data(), buffer_size);
}
return result;
@@ -136,8 +140,9 @@ std::string ToChatPrompt(const llama_model* model,
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
const std::array<llama_chat_message, 1> fallback_msg = {{
{.role = "user", .content = combined_prompt.c_str()},
}};
template_result = apply_template_with_resize(fallback_msg.data(), 1);
@@ -188,102 +193,17 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
"LlamaGenerator: failed to decode sampled token piece");
}
// Shared scanner used by the public extractor and by JSON validation.
//
// Walks `text` once and remembers the most recent top-level `{...}` span whose
// braces balance. Braces that appear between double quotes are ignored, and a
// backslash inside quotes protects the character that follows it. Quote
// tracking is active everywhere in the input, not only inside an object.
//
// Returns true and assigns the span to `json_out` when at least one balanced
// object was seen; otherwise returns false and leaves `json_out` untouched.
static bool ExtractLastJsonObject(const std::string& text,
                                  std::string& json_out) {
  constexpr std::size_t kUnset = std::string::npos;

  // Offset of the '{' that opened the object currently being scanned.
  std::size_t object_begin = kUnset;
  // Offset and length of the latest fully balanced object.
  std::size_t best_begin = kUnset;
  std::size_t best_length = 0;
  // Brace nesting level outside of quoted text.
  int nesting = 0;
  bool inside_quotes = false;
  bool pending_escape = false;

  for (std::size_t pos = 0; pos < text.size(); ++pos) {
    const char current = text[pos];

    // While quoted, only an escape or the closing quote changes state.
    if (inside_quotes) {
      if (pending_escape) {
        pending_escape = false;
      } else if (current == '\\') {
        pending_escape = true;
      } else if (current == '"') {
        inside_quotes = false;
      }
      continue;
    }

    switch (current) {
      case '"':
        inside_quotes = true;
        break;
      case '{':
        // Only the outermost brace marks the start of a candidate.
        if (nesting == 0) {
          object_begin = pos;
        }
        ++nesting;
        break;
      case '}':
        // A '}' with nothing open is stray and is skipped.
        if (nesting > 0) {
          --nesting;
          if (nesting == 0 && object_begin != kUnset) {
            // Later objects overwrite earlier ones, so the last one wins.
            best_begin = object_begin;
            best_length = pos - object_begin + 1;
          }
        }
        break;
      default:
        break;
    }
  }

  if (best_begin == kUnset) {
    return false;
  }

  json_out = text.substr(best_begin, best_length);
  return true;
}
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
auto validate_object = [&](const boost::json::value& json_value,
std::string& error_out) -> bool {
if (!jv.is_object()) {
if (!json_value.is_object()) {
error_out = "JSON root must be an object";
return false;
}
const auto& obj = jv.get_object();
const auto& obj = json_value.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
@@ -313,14 +233,15 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::ranges::transform(name_lower, name_lower.begin(),
[](unsigned char character) {
return static_cast<char>(std::tolower(character));
});
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
std::ranges::transform(description_lower, description_lower.begin(),
[](unsigned char character) {
return static_cast<char>(std::tolower(character));
});
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
@@ -331,41 +252,16 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
boost::system::error_code error_code;
boost::json::value json_value = boost::json::parse(raw, error_code);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
if (error_code) {
return "JSON parse error: " + error_code.message();
}
if (!validate_object(jv, validation_error)) {
if (!validate_object(json_value, validation_error)) {
return validation_error;
}
return std::nullopt;
}
std::string ExtractLastJsonObject(const std::string& text) {
// Reuse the internal parser and return an empty string if none was found.
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
return {};
}