From fcc7a5dc8b24587e150179fa0b06b510a5823cd0 Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Thu, 16 Apr 2026 20:06:36 -0400 Subject: [PATCH] Enhance ValidateBreweryJson to include reasoning output and update GenerateBrewery to use user_prompt Add gemma parser --- pipeline/CMakeLists.txt | 1 + .../data_generation/llama_generator.h | 6 +- .../data_generation/llama_generator_helpers.h | 16 +-- .../gemma4_jinja_prompt_formatter.h | 15 +++ .../prompt_formatting/prompt_formatter.h | 18 +++ pipeline/prompts/system.md | 1 - .../data_generation/llama/generate_brewery.cc | 43 +++++-- pipeline/src/data_generation/llama/helpers.cc | 119 ++++-------------- pipeline/src/data_generation/llama/infer.cc | 2 +- .../data_generation/llama/llama_generator.cc | 11 +- .../gemma4_jinja_prompt_formatter.cc | 32 +++++ pipeline/src/main.cc | 2 + 12 files changed, 144 insertions(+), 122 deletions(-) create mode 100644 pipeline/includes/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h create mode 100644 pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h create mode 100644 pipeline/src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc diff --git a/pipeline/CMakeLists.txt b/pipeline/CMakeLists.txt index 9771444..02f769b 100644 --- a/pipeline/CMakeLists.txt +++ b/pipeline/CMakeLists.txt @@ -107,6 +107,7 @@ set(SOURCES src/data_generation/llama/infer.cc src/data_generation/llama/load.cc src/data_generation/llama/load_brewery_prompt.cc + src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc src/data_generation/mock/deterministic_hash.cc src/data_generation/mock/generate_brewery.cc src/data_generation/mock/generate_user.cc diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index ef0189b..4b7729f 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -15,6 +15,7 @@ #include #include "data_generation/data_generator.h" +#include "data_generation/prompt_formatting/prompt_formatter.h" #include "data_model/application_options.h" struct llama_model; @@ -31,9 +32,11 @@ class LlamaGenerator final : public DataGenerator { * * @param options Parsed application options. * @param model_path Filesystem path to GGUF model assets. + * @param prompt_formatter Formatter that produces model-specific prompts. */ LlamaGenerator(const ApplicationOptions& options, - const std::string& model_path); + const std::string& model_path, + std::shared_ptr prompt_formatter); ~LlamaGenerator() override; @@ -132,6 +135,7 @@ class LlamaGenerator final : public DataGenerator { std::mt19937 rng_; uint32_t n_ctx_ = kDefaultContextSize; std::string brewery_system_prompt_; + std::shared_ptr prompt_formatter_; }; #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ diff --git a/pipeline/includes/data_generation/llama_generator_helpers.h b/pipeline/includes/data_generation/llama_generator_helpers.h index b1e26db..d2224a8 100644 --- a/pipeline/includes/data_generation/llama_generator_helpers.h +++ b/pipeline/includes/data_generation/llama_generator_helpers.h @@ -12,7 +12,6 @@ #include #include -struct llama_model; struct llama_vocab; using llama_token = int32_t; @@ -26,18 +25,6 @@ using llama_token = int32_t; std::string PrepareRegionContext(std::string_view region_context, size_t max_chars = 2000); -/** - * @brief Applies model chat template to system and user prompts. - * - * @param model Loaded llama model. - * @param system_prompt System prompt text. - * @param user_prompt User prompt text. - * @return Model-formatted prompt. - */ -std::string ToChatPrompt(const llama_model* model, - const std::string& system_prompt, - const std::string& user_prompt); - /** * @brief Decodes a sampled token and appends it to output text. * @@ -58,6 +45,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token, */ std::optional ValidateBreweryJson(const std::string& raw, std::string& name_out, - std::string& description_out); + std::string& description_out, + std::string& reasoning_out); #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ diff --git a/pipeline/includes/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h b/pipeline/includes/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h new file mode 100644 index 0000000..cf08bbc --- /dev/null +++ b/pipeline/includes/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include "data_generation/prompt_formatting/prompt_formatter.h" + +class Gemma4JinjaPromptFormatter final : public IPromptFormatter { + public: + Gemma4JinjaPromptFormatter() = default; + ~Gemma4JinjaPromptFormatter() override = default; + + [[nodiscard]] std::string Format(std::string_view system_prompt, + std::string_view user_prompt) const override; +}; diff --git a/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h b/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h new file mode 100644 index 0000000..7498397 --- /dev/null +++ b/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +class IPromptFormatter { + public: + IPromptFormatter() = default; + IPromptFormatter(const IPromptFormatter&) = delete; + IPromptFormatter& operator=(const IPromptFormatter&) = delete; + IPromptFormatter(IPromptFormatter&&) = delete; + IPromptFormatter& operator=(IPromptFormatter&&) = delete; + virtual ~IPromptFormatter() = default; + + [[nodiscard]] virtual std::string Format( + std::string_view system_prompt, + std::string_view user_prompt) const = 0; +}; diff --git a/pipeline/prompts/system.md b/pipeline/prompts/system.md index 4c26f4d..0eac845 100644 --- a/pipeline/prompts/system.md +++ b/pipeline/prompts/system.md @@ -1,4 +1,3 @@ -<|think|> Return only one raw JSON object as the final answer, with exactly three keys: "reasoning", "name", and "description". The "reasoning" key MUST be the first key in the object. No markdown, code fences, preamble, or extra keys. diff --git a/pipeline/src/data_generation/llama/generate_brewery.cc b/pipeline/src/data_generation/llama/generate_brewery.cc index 2c90ef8..028af9f 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cc +++ b/pipeline/src/data_generation/llama/generate_brewery.cc @@ -25,6 +25,10 @@ escape ::= ["\\/bfnrt] | "u" hex hex hex hex hex ::= [0-9a-fA-F] )json_brewery"; +static constexpr int kBreweryInitialMaxTokens = 2800; +static constexpr int kBreweryTruncationRetryTokenBump = 700; +static constexpr int kBreweryMaxTokensCeiling = 5000; + BreweryResult LlamaGenerator::GenerateBrewery( const Location& location, const std::string& region_context) { /** @@ -43,11 +47,8 @@ BreweryResult LlamaGenerator::GenerateBrewery( const std::string system_prompt = LoadBrewerySystemPrompt("prompts/system.md"); - /** - * User prompt: provides geographic context to guide generation towards - * culturally relevant and locally-inspired brewery attributes - */ - std::string prompt = std::format( + + std::string user_prompt = std::format( "## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}", location.city, location.country, safe_region_context); @@ -66,11 +67,14 @@ BreweryResult LlamaGenerator::GenerateBrewery( std::string raw; std::string last_error; + // Token budget: too small risks truncating valid JSON mid-string. + // Start conservatively but allow adaptive increases on truncation. + int max_tokens = kBreweryInitialMaxTokens; + // Limit output length to keep it concise and focused for (int attempt = 0; attempt < max_attempts; ++attempt) { - constexpr int max_tokens = 1052; // Generate brewery data from LLM - raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar); + raw = this->Infer(system_prompt, user_prompt, max_tokens, kBreweryJsonGrammar); spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, raw); @@ -78,10 +82,16 @@ BreweryResult LlamaGenerator::GenerateBrewery( std::string name; std::string description; + std::string reasoning; const std::optional validation_error = - ValidateBreweryJson(raw, name, description); + ValidateBreweryJson(raw, name, description, reasoning); if (!validation_error.has_value()) { // Success: return parsed brewery data + + spdlog::info( + "LlamaGenerator: successfully generated brewery data on attempt {}:\n reasoning='{}',\n name='{}',\n description='{}'", + attempt + 1, reasoning, name, description); + return BreweryResult{.name = std::move(name), .description = std::move(description)}; } @@ -92,12 +102,27 @@ BreweryResult LlamaGenerator::GenerateBrewery( spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", attempt + 1, *validation_error); + + if (last_error == "JSON parse error: incomplete JSON") { + const int previous_max_tokens = max_tokens; + max_tokens = std::min(max_tokens + kBreweryTruncationRetryTokenBump, + kBreweryMaxTokensCeiling); + spdlog::info( + "LlamaGenerator: detected truncated JSON; increasing max_tokens from {} to {} and retrying", + previous_max_tokens, max_tokens); + + + continue; + } + // Update prompt with error details to guide LLM toward correct output. - prompt = std::format( + user_prompt = std::format( R"(Your previous response was invalid. Error: {} Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "", "name": "", "description": ""}}. Do not include markdown, comments, extra keys, or literal placeholder values. +Keep the JSON strings concise enough to fit within the token budget. + {})", *validation_error, retry_location); } diff --git a/pipeline/src/data_generation/llama/helpers.cc b/pipeline/src/data_generation/llama/helpers.cc index 454d340..35f88bb 100644 --- a/pipeline/src/data_generation/llama/helpers.cc +++ b/pipeline/src/data_generation/llama/helpers.cc @@ -4,8 +4,6 @@ * parsing, token decoding, and JSON validation helpers for Llama modules. */ -#include - #include #include #include @@ -81,89 +79,6 @@ std::string PrepareRegionContext(std::string_view region_context, return normalized; } -std::string ToChatPrompt(const llama_model* model, - const std::string& system_prompt, - const std::string& user_prompt) { - std::string combined_prompt = - std::format("{}\n\n{}", system_prompt, user_prompt); - - const char* template_str = llama_model_chat_template(model, nullptr); - - // If metadata is missing (nullptr), attempt to use the built-in "gemma" alias - // to leverage the library's interleaved template for Gemma 4 support. - if (template_str == nullptr) { - template_str = "gemma"; - spdlog::info( - "LlamaGenerator: model chat template metadata missing; attempting " - "built-in 'gemma' alias"); - } - - const std::array messages = {{ - {.role = "system", .content = system_prompt.c_str()}, - {.role = "user", .content = user_prompt.c_str()}, - }}; - - constexpr std::size_t min_template_buffer_size = 1024; - - std::vector buffer( - std::max(min_template_buffer_size, - (system_prompt.size() + user_prompt.size()) * 4)); - - auto apply_template_with_resize = [&](const char* tmpl, - const llama_chat_message* chat_messages, - int32_t message_count) -> int32_t { - int32_t result = llama_chat_apply_template( - tmpl, chat_messages, message_count, true, buffer.data(), - static_cast(buffer.size())); - - if (result < 0) { - return result; - } - - const auto buffer_size = static_cast(buffer.size()); - if (result >= buffer_size) { - buffer.resize(static_cast(result) + 1); - result = llama_chat_apply_template( - tmpl, chat_messages, message_count, true, buffer.data(), - static_cast(buffer.size())); - } - - return result; - }; - - int32_t template_result = - apply_template_with_resize(template_str, messages.data(), 2); - - if (template_result >= 0) { - return {buffer.data(), static_cast(template_result)}; - } - - spdlog::warn( - "LlamaGenerator: chat template rejected system/user messages (result " - "{}); trying single user fallback", - template_result); - - // FALLBACK: If the template fails (e.g., model rejecting the "system" role), - // combine the system and user prompts into a single "user" message. - const std::array fallback_msg = {{ - {.role = "user", .content = combined_prompt.c_str()}, - }}; - - template_result = - apply_template_with_resize(template_str, fallback_msg.data(), 1); - - // Ultimate fallback: if GGUF template parsing still fails, use raw text. - if (template_result < 0) { - spdlog::warn( - "LlamaGenerator: chat template fallback failed (result {}); using " - "raw prompt text", - template_result); - return combined_prompt; - } - - return {buffer.data(), static_cast(template_result)}; -} - void AppendTokenPiece(const llama_vocab* vocab, llama_token token, std::string& output) { constexpr size_t initial_buffer_size = 256; @@ -193,6 +108,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token, if (!buffer_too_small(bytes)) { output.append(dynamic_buffer.data(), static_cast(bytes)); + return; } throw std::runtime_error( @@ -201,7 +117,8 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token, std::optional ValidateBreweryJson(const std::string& raw, std::string& name_out, - std::string& description_out) { + std::string& description_out, + std::string& reasoning_out) { auto validate_object = [&](const boost::json::value& json_value, std::string& error_out) -> bool { if (!json_value.is_object()) { @@ -209,7 +126,14 @@ std::optional ValidateBreweryJson(const std::string& raw, return false; } + const auto& obj = json_value.get_object(); + + if (!obj.contains("reasoning") || !obj.at("reasoning").is_string()) { + error_out = "JSON field 'reasoning' is missing or not a string"; + return false; + } + if (!obj.contains("name") || !obj.at("name").is_string()) { error_out = "JSON field 'name' is missing or not a string"; return false; @@ -219,6 +143,12 @@ std::optional ValidateBreweryJson(const std::string& raw, error_out = "JSON field 'description' is missing or not a string"; return false; } + const auto& reasoning_value = obj.at("reasoning").as_string(); + reasoning_out = Trim(std::string_view(reasoning_value.data(), reasoning_value.size())); + if (reasoning_out.empty()) { + error_out = "JSON field 'reasoning' must not be empty"; + return false; + } const auto& name_value = obj.at("name").as_string(); const auto& description_value = obj.at("description").as_string(); @@ -239,15 +169,16 @@ std::optional ValidateBreweryJson(const std::string& raw, std::string name_lower = name_out; std::string description_lower = description_out; - std::ranges::transform(name_lower, name_lower.begin(), - [](unsigned char character) { - return static_cast(std::tolower(character)); - }); - std::ranges::transform(description_lower, description_lower.begin(), - [](unsigned char character) { - return static_cast(std::tolower(character)); - }); + auto string_to_lower = [](std::string& str_out) { + std::ranges::transform(str_out, str_out.begin(), + [](unsigned char character) { + return static_cast(std::tolower(character)); + }); + }; + + string_to_lower(name_lower); + string_to_lower(description_lower); if (name_lower == "string" || description_lower == "string") { error_out = "JSON appears to be a schema placeholder, not content"; diff --git a/pipeline/src/data_generation/llama/infer.cc b/pipeline/src/data_generation/llama/infer.cc index e3604b7..81a3754 100644 --- a/pipeline/src/data_generation/llama/infer.cc +++ b/pipeline/src/data_generation/llama/infer.cc @@ -75,7 +75,7 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, const int max_tokens, std::string_view grammar) { - return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt), + return InferFormatted(prompt_formatter_->Format(system_prompt, prompt), max_tokens, grammar); } diff --git a/pipeline/src/data_generation/llama/llama_generator.cc b/pipeline/src/data_generation/llama/llama_generator.cc index 61ddf8a..8a54e7e 100644 --- a/pipeline/src/data_generation/llama/llama_generator.cc +++ b/pipeline/src/data_generation/llama/llama_generator.cc @@ -31,12 +31,19 @@ void LlamaGenerator::ContextDeleter::operator()( } LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, - const std::string& model_path) - : rng_(std::random_device{}()) { + const std::string& model_path, + std::shared_ptr prompt_formatter) + : rng_(std::random_device{}()), + prompt_formatter_(std::move(prompt_formatter)) { if (model_path.empty()) { throw std::runtime_error("LlamaGenerator: model path must not be empty"); } + if (!prompt_formatter_) { + throw std::runtime_error( + "LlamaGenerator: prompt formatter dependency must not be null"); + } + if (options.temperature < 0.0F) { throw std::runtime_error( "LlamaGenerator: sampling temperature must be >= 0"); diff --git a/pipeline/src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc b/pipeline/src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc new file mode 100644 index 0000000..dfe94b6 --- /dev/null +++ b/pipeline/src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc @@ -0,0 +1,32 @@ +#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h" + +#include +#include +#include + +static constexpr std::string_view kWhitespace = " \t\n\r\f\v"; + +// Strips leading and trailing whitespace to ensure clean prompt injection. +static std::string_view Trim(std::string_view value) { + const size_t first_index = value.find_first_not_of(kWhitespace); + + const bool is_all_whitespace = (first_index == std::string_view::npos); + if (is_all_whitespace) { + return ""; + } + + const size_t last_index = value.find_last_not_of(kWhitespace); + return value.substr(first_index, last_index - first_index + 1); +} + +std::string Gemma4JinjaPromptFormatter::Format( + std::string_view system_prompt, std::string_view user_prompt) const { + std::string_view trimmed_system = Trim(system_prompt); + std::string_view trimmed_user = Trim(user_prompt); + + return std::format( + "<|turn|>system\n<|think|>\n{}\n<|turn|>\n" + "<|turn|>user\n{}\n<|turn|>\n" + "<|turn|>model\n<|channel>thought\n", + trimmed_system, trimmed_user); +} diff --git a/pipeline/src/main.cc b/pipeline/src/main.cc index 65699a7..024318b 100644 --- a/pipeline/src/main.cc +++ b/pipeline/src/main.cc @@ -17,6 +17,7 @@ #include "biergarten_data_generator.h" #include "data_generation/llama_generator.h" #include "data_generation/mock_generator.h" +#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h" #include "data_model/application_options.h" #include "llama_backend_state.h" #include "services/enrichment_service.h" @@ -147,6 +148,7 @@ int main(const int argc, char** argv) { di::bind().to(), di::bind().to(options), di::bind().to(), + di::bind().to(), di::bind().to(options.model_path), di::bind().to( [options](const auto& inj) -> std::unique_ptr {