Enhance ValidateBreweryJson to include reasoning output and update GenerateBrewery to use user_prompt

Add Gemma 4 prompt formatter
This commit is contained in:
Aaron Po
2026-04-16 20:06:36 -04:00
parent 44a74ed2ad
commit fcc7a5dc8b
12 changed files with 144 additions and 122 deletions

View File

@@ -107,6 +107,7 @@ set(SOURCES
src/data_generation/llama/infer.cc src/data_generation/llama/infer.cc
src/data_generation/llama/load.cc src/data_generation/llama/load.cc
src/data_generation/llama/load_brewery_prompt.cc src/data_generation/llama/load_brewery_prompt.cc
src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc
src/data_generation/mock/deterministic_hash.cc src/data_generation/mock/deterministic_hash.cc
src/data_generation/mock/generate_brewery.cc src/data_generation/mock/generate_brewery.cc
src/data_generation/mock/generate_user.cc src/data_generation/mock/generate_user.cc

View File

@@ -15,6 +15,7 @@
#include <string_view> #include <string_view>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "data_generation/prompt_formatting/prompt_formatter.h"
#include "data_model/application_options.h" #include "data_model/application_options.h"
struct llama_model; struct llama_model;
@@ -31,9 +32,11 @@ class LlamaGenerator final : public DataGenerator {
* *
* @param options Parsed application options. * @param options Parsed application options.
* @param model_path Filesystem path to GGUF model assets. * @param model_path Filesystem path to GGUF model assets.
* @param prompt_formatter Formatter that produces model-specific prompts.
*/ */
LlamaGenerator(const ApplicationOptions& options, LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path); const std::string& model_path,
std::shared_ptr<IPromptFormatter> prompt_formatter);
~LlamaGenerator() override; ~LlamaGenerator() override;
@@ -132,6 +135,7 @@ class LlamaGenerator final : public DataGenerator {
std::mt19937 rng_; std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize; uint32_t n_ctx_ = kDefaultContextSize;
std::string brewery_system_prompt_; std::string brewery_system_prompt_;
std::shared_ptr<IPromptFormatter> prompt_formatter_;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -12,7 +12,6 @@
#include <string> #include <string>
#include <string_view> #include <string_view>
struct llama_model;
struct llama_vocab; struct llama_vocab;
using llama_token = int32_t; using llama_token = int32_t;
@@ -26,18 +25,6 @@ using llama_token = int32_t;
std::string PrepareRegionContext(std::string_view region_context, std::string PrepareRegionContext(std::string_view region_context,
size_t max_chars = 2000); size_t max_chars = 2000);
/**
* @brief Applies model chat template to system and user prompts.
*
* @param model Loaded llama model.
* @param system_prompt System prompt text.
* @param user_prompt User prompt text.
* @return Model-formatted prompt.
*/
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt);
/** /**
* @brief Decodes a sampled token and appends it to output text. * @brief Decodes a sampled token and appends it to output text.
* *
@@ -58,6 +45,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
*/ */
std::optional<std::string> ValidateBreweryJson(const std::string& raw, std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string& name_out, std::string& name_out,
std::string& description_out); std::string& description_out,
std::string& reasoning_out);
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_

View File

@@ -0,0 +1,15 @@
#pragma once
#include <string>
#include <string_view>
#include "data_generation/prompt_formatting/prompt_formatter.h"
/**
 * @brief Concrete IPromptFormatter for Gemma 4 prompts (per the class name).
 *
 * Only the declaration lives in this header; Format is defined out of line.
 * The class is final and adds no data members of its own.
 */
class Gemma4JinjaPromptFormatter final : public IPromptFormatter {
public:
Gemma4JinjaPromptFormatter() = default;
~Gemma4JinjaPromptFormatter() override = default;
/**
 * @brief Builds a single prompt string from a system and a user prompt.
 *
 * @param system_prompt System prompt text.
 * @param user_prompt User prompt text.
 * @return Formatted prompt; marked nodiscard, so callers must use it.
 */
[[nodiscard]] std::string Format(std::string_view system_prompt,
std::string_view user_prompt) const override;
};

View File

@@ -0,0 +1,18 @@
#pragma once
#include <string>
#include <string_view>
/**
 * @brief Interface for turning a system prompt and a user prompt into one
 *        prompt string.
 *
 * Instances are neither copyable nor movable, so implementations are handled
 * through pointers or references to this base.
 */
class IPromptFormatter {
public:
IPromptFormatter() = default;
// Copy and move are deleted: formatters are used polymorphically and
// copying through the base would slice the derived object.
IPromptFormatter(const IPromptFormatter&) = delete;
IPromptFormatter& operator=(const IPromptFormatter&) = delete;
IPromptFormatter(IPromptFormatter&&) = delete;
IPromptFormatter& operator=(IPromptFormatter&&) = delete;
// Virtual so that destroying a derived formatter through this base is safe.
virtual ~IPromptFormatter() = default;
/**
 * @brief Produces the formatted prompt text.
 *
 * @param system_prompt System prompt text.
 * @param user_prompt User prompt text.
 * @return Formatted prompt; marked nodiscard, so callers must use it.
 */
[[nodiscard]] virtual std::string Format(
std::string_view system_prompt,
std::string_view user_prompt) const = 0;
};

View File

@@ -1,4 +1,3 @@
<|think|>
Return only one raw JSON object as the final answer, with exactly three keys: "reasoning", "name", and "description". Return only one raw JSON object as the final answer, with exactly three keys: "reasoning", "name", and "description".
The "reasoning" key MUST be the first key in the object. The "reasoning" key MUST be the first key in the object.
No markdown, code fences, preamble, or extra keys. No markdown, code fences, preamble, or extra keys.

View File

@@ -25,6 +25,10 @@ escape ::= ["\\/bfnrt] | "u" hex hex hex hex
hex ::= [0-9a-fA-F] hex ::= [0-9a-fA-F]
)json_brewery"; )json_brewery";
// Token budget used for the first brewery inference attempt.
static constexpr int kBreweryInitialMaxTokens = 2800;
// Amount added to the token budget when the output is detected as truncated
// JSON and the attempt is retried.
static constexpr int kBreweryTruncationRetryTokenBump = 700;
// Upper bound the token budget is clamped to when it is increased.
static constexpr int kBreweryMaxTokensCeiling = 5000;
BreweryResult LlamaGenerator::GenerateBrewery( BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) { const Location& location, const std::string& region_context) {
/** /**
@@ -43,11 +47,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
const std::string system_prompt = const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md"); LoadBrewerySystemPrompt("prompts/system.md");
/**
* User prompt: provides geographic context to guide generation towards std::string user_prompt = std::format(
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt = std::format(
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}", "## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}",
location.city, location.country, safe_region_context); location.city, location.country, safe_region_context);
@@ -66,11 +67,14 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string raw; std::string raw;
std::string last_error; std::string last_error;
// Token budget: too small risks truncating valid JSON mid-string.
// Start conservatively but allow adaptive increases on truncation.
int max_tokens = kBreweryInitialMaxTokens;
// Limit output length to keep it concise and focused // Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) { for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM // Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar); raw = this->Infer(system_prompt, user_prompt, max_tokens, kBreweryJsonGrammar);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw); raw);
@@ -78,10 +82,16 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string name; std::string name;
std::string description; std::string description;
std::string reasoning;
const std::optional<std::string> validation_error = const std::optional<std::string> validation_error =
ValidateBreweryJson(raw, name, description); ValidateBreweryJson(raw, name, description, reasoning);
if (!validation_error.has_value()) { if (!validation_error.has_value()) {
// Success: return parsed brewery data // Success: return parsed brewery data
spdlog::info(
"LlamaGenerator: successfully generated brewery data on attempt {}:\n reasoning='{}',\n name='{}',\n description='{}'",
attempt + 1, reasoning, name, description);
return BreweryResult{.name = std::move(name), return BreweryResult{.name = std::move(name),
.description = std::move(description)}; .description = std::move(description)};
} }
@@ -92,12 +102,27 @@ BreweryResult LlamaGenerator::GenerateBrewery(
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error); attempt + 1, *validation_error);
if (last_error == "JSON parse error: incomplete JSON") {
const int previous_max_tokens = max_tokens;
max_tokens = std::min(max_tokens + kBreweryTruncationRetryTokenBump,
kBreweryMaxTokensCeiling);
spdlog::info(
"LlamaGenerator: detected truncated JSON; increasing max_tokens from {} to {} and retrying",
previous_max_tokens, max_tokens);
continue;
}
// Update prompt with error details to guide LLM toward correct output. // Update prompt with error details to guide LLM toward correct output.
prompt = std::format( user_prompt = std::format(
R"(Your previous response was invalid. Error: {} R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "<brief planning summary>", "name": "<brewery name>", "description": "<single-paragraph description>"}}. Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "<brief planning summary>", "name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values. Do not include markdown, comments, extra keys, or literal placeholder values.
Keep the JSON strings concise enough to fit within the token budget.
{})", {})",
*validation_error, retry_location); *validation_error, retry_location);
} }

View File

@@ -4,8 +4,6 @@
* parsing, token decoding, and JSON validation helpers for Llama modules. * parsing, token decoding, and JSON validation helpers for Llama modules.
*/ */
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <boost/json.hpp> #include <boost/json.hpp>
@@ -81,89 +79,6 @@ std::string PrepareRegionContext(std::string_view region_context,
return normalized; return normalized;
} }
/**
 * @brief Applies the model's chat template to a system and a user prompt.
 *
 * Tries three strategies in order and returns the first that succeeds:
 *  1. the template applied to separate "system" and "user" messages;
 *  2. the template applied to one "user" message holding both prompts;
 *  3. the raw text "<system_prompt>\n\n<user_prompt>" with no template.
 *
 * @param model Model whose chat template metadata is queried.
 * @param system_prompt System prompt text.
 * @param user_prompt User prompt text.
 * @return Templated prompt, or the raw combined text if templating fails.
 */
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
// Built up front because both fallback paths below need it.
std::string combined_prompt =
std::format("{}\n\n{}", system_prompt, user_prompt);
const char* template_str = llama_model_chat_template(model, nullptr);
// If metadata is missing (nullptr), attempt to use the built-in "gemma" alias
// to leverage the library's interleaved template for Gemma 4 support.
if (template_str == nullptr) {
template_str = "gemma";
spdlog::info(
"LlamaGenerator: model chat template metadata missing; attempting "
"built-in 'gemma' alias");
}
// The messages hold c_str() pointers into the caller's strings, which stay
// alive for the whole call.
const std::array<llama_chat_message, 2> messages = {{
{.role = "system", .content = system_prompt.c_str()},
{.role = "user", .content = user_prompt.c_str()},
}};
// Initial output buffer: at least 1024 bytes, otherwise four times the
// combined prompt length.
constexpr std::size_t min_template_buffer_size = 1024;
std::vector<char> buffer(
std::max<std::size_t>(min_template_buffer_size,
(system_prompt.size() + user_prompt.size()) * 4));
// Applies the template into `buffer` (captured by reference). A negative
// result is passed straight back as failure. A result at or above the buffer
// size grows the buffer to result + 1 and repeats the call once.
auto apply_template_with_resize = [&](const char* tmpl,
const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (result < 0) {
return result;
}
const auto buffer_size = static_cast<int32_t>(buffer.size());
if (result >= buffer_size) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
return result;
};
// First attempt: separate system and user messages.
int32_t template_result =
apply_template_with_resize(template_str, messages.data(), 2);
if (template_result >= 0) {
// The non-negative result is used as the length of the text in `buffer`.
return {buffer.data(), static_cast<size_t>(template_result)};
}
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// FALLBACK: If the template fails (e.g., model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {{
{.role = "user", .content = combined_prompt.c_str()},
}};
template_result =
apply_template_with_resize(template_str, fallback_msg.data(), 1);
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
return {buffer.data(), static_cast<size_t>(template_result)};
}
void AppendTokenPiece(const llama_vocab* vocab, llama_token token, void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
constexpr size_t initial_buffer_size = 256; constexpr size_t initial_buffer_size = 256;
@@ -193,6 +108,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
if (!buffer_too_small(bytes)) { if (!buffer_too_small(bytes)) {
output.append(dynamic_buffer.data(), static_cast<size_t>(bytes)); output.append(dynamic_buffer.data(), static_cast<size_t>(bytes));
return;
} }
throw std::runtime_error( throw std::runtime_error(
@@ -201,7 +117,8 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::optional<std::string> ValidateBreweryJson(const std::string& raw, std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string& name_out, std::string& name_out,
std::string& description_out) { std::string& description_out,
std::string& reasoning_out) {
auto validate_object = [&](const boost::json::value& json_value, auto validate_object = [&](const boost::json::value& json_value,
std::string& error_out) -> bool { std::string& error_out) -> bool {
if (!json_value.is_object()) { if (!json_value.is_object()) {
@@ -209,7 +126,14 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
return false; return false;
} }
const auto& obj = json_value.get_object(); const auto& obj = json_value.get_object();
if (!obj.contains("reasoning") || !obj.at("reasoning").is_string()) {
error_out = "JSON field 'reasoning' is missing or not a string";
return false;
}
if (!obj.contains("name") || !obj.at("name").is_string()) { if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string"; error_out = "JSON field 'name' is missing or not a string";
return false; return false;
@@ -219,6 +143,12 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
error_out = "JSON field 'description' is missing or not a string"; error_out = "JSON field 'description' is missing or not a string";
return false; return false;
} }
const auto& reasoning_value = obj.at("reasoning").as_string();
reasoning_out = Trim(std::string_view(reasoning_value.data(), reasoning_value.size()));
if (reasoning_out.empty()) {
error_out = "JSON field 'reasoning' must not be empty";
return false;
}
const auto& name_value = obj.at("name").as_string(); const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string(); const auto& description_value = obj.at("description").as_string();
@@ -239,15 +169,16 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string name_lower = name_out; std::string name_lower = name_out;
std::string description_lower = description_out; std::string description_lower = description_out;
std::ranges::transform(name_lower, name_lower.begin(),
[](unsigned char character) {
return static_cast<char>(std::tolower(character));
});
std::ranges::transform(description_lower, description_lower.begin(), auto string_to_lower = [](std::string& str_out) {
[](unsigned char character) { std::ranges::transform(str_out, str_out.begin(),
return static_cast<char>(std::tolower(character)); [](unsigned char character) {
}); return static_cast<char>(std::tolower(character));
});
};
string_to_lower(name_lower);
string_to_lower(description_lower);
if (name_lower == "string" || description_lower == "string") { if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content"; error_out = "JSON appears to be a schema placeholder, not content";

View File

@@ -75,7 +75,7 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, const std::string& prompt,
const int max_tokens, const int max_tokens,
std::string_view grammar) { std::string_view grammar) {
return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt), return InferFormatted(prompt_formatter_->Format(system_prompt, prompt),
max_tokens, grammar); max_tokens, grammar);
} }

View File

@@ -31,12 +31,19 @@ void LlamaGenerator::ContextDeleter::operator()(
} }
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) const std::string& model_path,
: rng_(std::random_device{}()) { std::shared_ptr<IPromptFormatter> prompt_formatter)
: rng_(std::random_device{}()),
prompt_formatter_(std::move(prompt_formatter)) {
if (model_path.empty()) { if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty"); throw std::runtime_error("LlamaGenerator: model path must not be empty");
} }
if (!prompt_formatter_) {
throw std::runtime_error(
"LlamaGenerator: prompt formatter dependency must not be null");
}
if (options.temperature < 0.0F) { if (options.temperature < 0.0F) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0"); "LlamaGenerator: sampling temperature must be >= 0");

View File

@@ -0,0 +1,32 @@
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
#include <format>
#include <string>
#include <string_view>
static constexpr std::string_view kWhitespace = " \t\n\r\f\v";

/**
 * @brief Drops leading and trailing whitespace from a string view.
 *
 * Whitespace means any character in kWhitespace.
 *
 * @param value Text to trim.
 * @return View of the span between the first and last non-whitespace
 *         characters of `value`, or an empty view if there are none. A
 *         non-empty result refers to the same storage as `value`.
 */
static std::string_view Trim(std::string_view value) {
  const size_t leading = value.find_first_not_of(kWhitespace);
  if (leading == std::string_view::npos) {
    // Nothing but whitespace (or empty input).
    return "";
  }
  value.remove_prefix(leading);

  // A non-whitespace character is known to exist, so this cannot be npos.
  const size_t last_kept = value.find_last_not_of(kWhitespace);
  value.remove_suffix(value.size() - (last_kept + 1));
  return value;
}
/**
 * @brief Wraps the trimmed system and user prompts in turn markers.
 *
 * The output has a system turn that starts with a `<|think|>` line, then a
 * user turn, then an opened model turn ending in `<|channel>thought\n` with
 * nothing after it.
 *
 * @param system_prompt System prompt text; surrounding whitespace is dropped.
 * @param user_prompt User prompt text; surrounding whitespace is dropped.
 * @return The assembled prompt string.
 */
std::string Gemma4JinjaPromptFormatter::Format(
    std::string_view system_prompt, std::string_view user_prompt) const {
  // NOTE(review): `<|channel>` has no closing pipe, unlike `<|turn|>` and
  // `<|think|>` — confirm the marker spelling against the model's template.
  return std::format(
      "<|turn|>system\n<|think|>\n{}\n<|turn|>\n"
      "<|turn|>user\n{}\n<|turn|>\n"
      "<|turn|>model\n<|channel>thought\n",
      Trim(system_prompt), Trim(user_prompt));
}

View File

@@ -17,6 +17,7 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
#include "data_model/application_options.h" #include "data_model/application_options.h"
#include "llama_backend_state.h" #include "llama_backend_state.h"
#include "services/enrichment_service.h" #include "services/enrichment_service.h"
@@ -147,6 +148,7 @@ int main(const int argc, char** argv) {
di::bind<WebClient>().to<CURLWebClient>(), di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options), di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(), di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<std::string>().to(options.model_path), di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to( di::bind<DataGenerator>().to(
[options](const auto& inj) -> std::unique_ptr<DataGenerator> { [options](const auto& inj) -> std::unique_ptr<DataGenerator> {