diff --git a/pipeline/includes/data_generation/llama_generator_helpers.h b/pipeline/includes/data_generation/llama_generator_helpers.h index b1e26db..eabecba 100644 --- a/pipeline/includes/data_generation/llama_generator_helpers.h +++ b/pipeline/includes/data_generation/llama_generator_helpers.h @@ -58,6 +58,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token, */ std::optional ValidateBreweryJson(const std::string& raw, std::string& name_out, - std::string& description_out); + std::string& description_out, + std::string& reasoning_out); #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ diff --git a/pipeline/src/data_generation/llama/generate_brewery.cc b/pipeline/src/data_generation/llama/generate_brewery.cc index 2c90ef8..e9f2210 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cc +++ b/pipeline/src/data_generation/llama/generate_brewery.cc @@ -43,11 +43,8 @@ BreweryResult LlamaGenerator::GenerateBrewery( const std::string system_prompt = LoadBrewerySystemPrompt("prompts/system.md"); - /** - * User prompt: provides geographic context to guide generation towards - * culturally relevant and locally-inspired brewery attributes - */ - std::string prompt = std::format( + + std::string user_prompt = std::format( "## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}", location.city, location.country, safe_region_context); @@ -70,7 +67,7 @@ BreweryResult LlamaGenerator::GenerateBrewery( for (int attempt = 0; attempt < max_attempts; ++attempt) { constexpr int max_tokens = 1052; // Generate brewery data from LLM - raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar); + raw = this->Infer(system_prompt, user_prompt, max_tokens, kBreweryJsonGrammar); spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, raw); @@ -78,10 +75,16 @@ BreweryResult LlamaGenerator::GenerateBrewery( std::string name; std::string description; + std::string reasoning; const std::optional validation_error = - ValidateBreweryJson(raw, name, description); + ValidateBreweryJson(raw, name, description, reasoning); if (!validation_error.has_value()) { // Success: return parsed brewery data + + spdlog::info( + "LlamaGenerator: successfully generated brewery data on attempt {}:\n reasoning='{}',\n name='{}',\n description='{}'", + attempt + 1, reasoning, name, description); + return BreweryResult{.name = std::move(name), .description = std::move(description)}; } @@ -93,7 +96,7 @@ BreweryResult LlamaGenerator::GenerateBrewery( attempt + 1, *validation_error); // Update prompt with error details to guide LLM toward correct output. - prompt = std::format( + user_prompt = std::format( R"(Your previous response was invalid. Error: {} Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "", "name": "", "description": ""}}. Do not include markdown, comments, extra keys, or literal placeholder values. diff --git a/pipeline/src/data_generation/llama/helpers.cc b/pipeline/src/data_generation/llama/helpers.cc index 454d340..a214e22 100644 --- a/pipeline/src/data_generation/llama/helpers.cc +++ b/pipeline/src/data_generation/llama/helpers.cc @@ -201,7 +201,8 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token, std::optional ValidateBreweryJson(const std::string& raw, std::string& name_out, - std::string& description_out) { + std::string& description_out, + std::string& reasoning_out) { auto validate_object = [&](const boost::json::value& json_value, std::string& error_out) -> bool { if (!json_value.is_object()) { @@ -209,7 +210,14 @@ std::optional ValidateBreweryJson(const std::string& raw, return false; } + const auto& obj = json_value.get_object(); + + if (!obj.contains("reasoning") || !obj.at("reasoning").is_string()) { + error_out = "JSON field 'reasoning' is missing or not a string"; + return false; + } + if (!obj.contains("name") || !obj.at("name").is_string()) { error_out = "JSON field 'name' is missing or not a string"; return false; @@ -219,6 +227,12 @@ std::optional ValidateBreweryJson(const std::string& raw, error_out = "JSON field 'description' is missing or not a string"; return false; } + const auto& reasoning_value = obj.at("reasoning").as_string(); + reasoning_out = Trim(std::string_view(reasoning_value.data(), reasoning_value.size())); + if (reasoning_out.empty()) { + error_out = "JSON field 'reasoning' must not be empty"; + return false; + } const auto& name_value = obj.at("name").as_string(); const auto& description_value = obj.at("description").as_string(); @@ -239,15 +253,16 @@ std::optional ValidateBreweryJson(const std::string& raw, std::string name_lower = name_out; std::string description_lower = description_out; - std::ranges::transform(name_lower, name_lower.begin(), - [](unsigned char character) { - return static_cast(std::tolower(character)); - }); - std::ranges::transform(description_lower, description_lower.begin(), - [](unsigned char character) { - return static_cast(std::tolower(character)); - }); + auto string_to_lower = [](std::string& str_out) { + std::ranges::transform(str_out, str_out.begin(), + [](unsigned char character) { + return static_cast(std::tolower(character)); + }); + }; + + string_to_lower(name_lower); + string_to_lower(description_lower); if (name_lower == "string" || description_lower == "string") { error_out = "JSON appears to be a schema placeholder, not content";