mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Add llama grammar to ensure proper json output
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
#include <boost/json.hpp>
|
||||
#include <cctype>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
@@ -97,11 +96,16 @@ std::string ToChatPrompt(const llama_model* model,
|
||||
return combined_prompt;
|
||||
}
|
||||
|
||||
const std::array<llama_chat_message, 2> messages = {
|
||||
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
|
||||
const std::array<llama_chat_message, 2> messages = {{
|
||||
{.role = "system", .content = system_prompt.c_str()},
|
||||
{.role = "user", .content = user_prompt.c_str()},
|
||||
}};
|
||||
|
||||
constexpr std::size_t min_template_buffer_size = 1024;
|
||||
|
||||
std::vector<char> buffer(std::max<std::size_t>(
|
||||
1024, (system_prompt.size() + user_prompt.size()) * 4));
|
||||
min_template_buffer_size,
|
||||
(system_prompt.size() + user_prompt.size()) * 4));
|
||||
|
||||
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
|
||||
int32_t message_count) -> int32_t {
|
||||
@@ -113,11 +117,11 @@ std::string ToChatPrompt(const llama_model* model,
|
||||
return result;
|
||||
}
|
||||
|
||||
// Retry once with a larger buffer when the rendered template did not fit.
// `result` holds the length reported by the first call, so the buffer is
// grown to that length plus one.
//
// The capacity handed to the second call must be read AFTER resize(): a size
// cached before the resize still describes the old, too-small buffer, so the
// retry would be told it has no more room than the first attempt had.
if (result >= static_cast<int32_t>(buffer.size())) {
  buffer.resize(static_cast<std::size_t>(result) + 1);
  result = llama_chat_apply_template(tmpl, chat_messages, message_count,
                                     true, buffer.data(),
                                     static_cast<int32_t>(buffer.size()));
}
|
||||
|
||||
return result;
|
||||
@@ -136,8 +140,9 @@ std::string ToChatPrompt(const llama_model* model,
|
||||
|
||||
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
|
||||
// combine the system and user prompts into a single "user" message.
|
||||
const std::array<llama_chat_message, 1> fallback_msg = {
|
||||
{{"user", combined_prompt.c_str()}}};
|
||||
const std::array<llama_chat_message, 1> fallback_msg = {{
|
||||
{.role = "user", .content = combined_prompt.c_str()},
|
||||
}};
|
||||
|
||||
template_result = apply_template_with_resize(fallback_msg.data(), 1);
|
||||
|
||||
@@ -188,102 +193,17 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||
"LlamaGenerator: failed to decode sampled token piece");
|
||||
}
|
||||
|
||||
// Shared parser used by the public extractor and JSON validation.
|
||||
// Scans `text` for brace-balanced `{...}` spans and reports the last one that
// closes back to nesting level zero.
//
// Double-quoted runs are skipped wholesale (honouring backslash escapes), so
// braces that appear inside quotes never affect nesting. A `}` seen while no
// object is open is ignored.
//
// @param text      Input to scan.
// @param json_out  Receives the last complete top-level object; left
//                  untouched when none is found.
// @return true when at least one balanced object was found, false otherwise.
static bool ExtractLastJsonObject(const std::string& text,
                                  std::string& json_out) {
  constexpr std::size_t kNotFound = std::string::npos;

  std::size_t open_pos = kNotFound;     // Where the currently open outermost object began.
  std::size_t best_begin = kNotFound;   // Start of the latest completed object.
  std::size_t best_length = 0;          // Length of the latest completed object.
  int nesting = 0;                      // Count of unmatched '{' outside quotes.
  bool inside_quotes = false;           // True while between a pair of '"'.
  bool skip_next = false;               // True right after a backslash inside quotes.

  for (std::size_t pos = 0; pos < text.size(); ++pos) {
    const char current = text[pos];

    if (inside_quotes) {
      // Within quotes only the escape marker and the closing quote matter.
      if (skip_next) {
        skip_next = false;
      } else if (current == '\\') {
        skip_next = true;
      } else if (current == '"') {
        inside_quotes = false;
      }
      continue;
    }

    switch (current) {
      case '"':
        inside_quotes = true;
        break;

      case '{':
        // Only the outermost brace marks the start of a candidate.
        if (nesting == 0) {
          open_pos = pos;
        }
        ++nesting;
        break;

      case '}':
        // A closer with nothing open is stray; otherwise pop one level and,
        // if that closes the outermost object, remember its extent.
        if (nesting > 0) {
          --nesting;
          if (nesting == 0 && open_pos != kNotFound) {
            best_begin = open_pos;
            best_length = pos - open_pos + 1;
          }
        }
        break;

      default:
        break;
    }
  }

  if (best_begin == kNotFound) {
    return false;
  }

  // Build the result in a temporary first so the call stays safe even if the
  // caller passes the same string for both parameters.
  json_out = text.substr(best_begin, best_length);
  return true;
}
|
||||
|
||||
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||
std::string& name_out,
|
||||
std::string& description_out) {
|
||||
auto validate_object = [&](const boost::json::value& jv,
|
||||
auto validate_object = [&](const boost::json::value& json_value,
|
||||
std::string& error_out) -> bool {
|
||||
if (!jv.is_object()) {
|
||||
if (!json_value.is_object()) {
|
||||
error_out = "JSON root must be an object";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& obj = jv.get_object();
|
||||
const auto& obj = json_value.get_object();
|
||||
if (!obj.contains("name") || !obj.at("name").is_string()) {
|
||||
error_out = "JSON field 'name' is missing or not a string";
|
||||
return false;
|
||||
@@ -313,14 +233,15 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||
std::string name_lower = name_out;
|
||||
std::string description_lower = description_out;
|
||||
|
||||
std::transform(
|
||||
name_lower.begin(), name_lower.end(), name_lower.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
std::ranges::transform(name_lower, name_lower.begin(),
|
||||
[](unsigned char character) {
|
||||
return static_cast<char>(std::tolower(character));
|
||||
});
|
||||
|
||||
std::transform(description_lower.begin(), description_lower.end(),
|
||||
description_lower.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
std::ranges::transform(description_lower, description_lower.begin(),
|
||||
[](unsigned char character) {
|
||||
return static_cast<char>(std::tolower(character));
|
||||
});
|
||||
|
||||
if (name_lower == "string" || description_lower == "string") {
|
||||
error_out = "JSON appears to be a schema placeholder, not content";
|
||||
@@ -331,41 +252,16 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||
return true;
|
||||
};
|
||||
|
||||
boost::system::error_code ec;
|
||||
boost::json::value jv = boost::json::parse(raw, ec);
|
||||
boost::system::error_code error_code;
|
||||
boost::json::value json_value = boost::json::parse(raw, error_code);
|
||||
std::string validation_error;
|
||||
if (ec) {
|
||||
std::string extracted;
|
||||
if (!ExtractLastJsonObject(raw, extracted)) {
|
||||
return "JSON parse error: " + ec.message();
|
||||
}
|
||||
|
||||
ec.clear();
|
||||
jv = boost::json::parse(extracted, ec);
|
||||
if (ec) {
|
||||
return "JSON parse error: " + ec.message();
|
||||
}
|
||||
|
||||
if (!validate_object(jv, validation_error)) {
|
||||
return validation_error;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
if (error_code) {
|
||||
return "JSON parse error: " + error_code.message();
|
||||
}
|
||||
|
||||
if (!validate_object(jv, validation_error)) {
|
||||
if (!validate_object(json_value, validation_error)) {
|
||||
return validation_error;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::string ExtractLastJsonObject(const std::string& text) {
|
||||
// Reuse the internal parser and return an empty string if none was found.
|
||||
std::string extracted;
|
||||
if (ExtractLastJsonObject(text, extracted)) {
|
||||
return extracted;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user