Refactor Llama generator, helpers, and build assets

make Gemma 4 the default model, enable thinking mode
style updates
This commit is contained in:
Aaron Po
2026-04-10 00:03:45 -04:00
parent 7ca651a886
commit 56ec728ba7
61 changed files with 1430 additions and 1905 deletions

View File

@@ -6,65 +6,109 @@
#include <spdlog/spdlog.h>
#include <array>
#include <format>
#include <optional>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
/**
 * Reduce a raw model response to the text that should be parsed as JSON.
 *
 * Steps, in order:
 *  1. Find the right-most occurrence of any marker in `separator_tokens` and
 *     discard everything up to and including it.
 *  2. Trim leading/trailing spaces, tabs, CR and LF from what remains.
 *  3. Pass the trimmed text to ExtractLastJsonObjectPublic; if it returns a
 *     non-empty string, return that, otherwise return the trimmed text as-is.
 *
 * @param raw_response Raw response text; taken by value and edited locally.
 * @return The extracted JSON candidate, or the trimmed remainder when
 *         ExtractLastJsonObjectPublic yields an empty string.
 */
static std::string ExtractFinalJsonPayload(std::string raw_response) {
  // Whitespace trim that works on a view, so no copy is made here.
  auto trim = [](const std::string_view text) -> std::string_view {
    const std::size_t first = text.find_first_not_of(" \t\n\r");
    if (first == std::string_view::npos) {
      return {};
    }
    const std::size_t last = text.find_last_not_of(" \t\n\r");
    return text.substr(first, last - first + 1);
  };

  static constexpr std::array<std::string_view, 6> separator_tokens = {
      "<|think|>", "<think|>",   "<|turn|>",
      "<turn|>",   "<channel|>", "<|channel|>"};

  // Track the marker that ends latest in the response; only the text after it
  // is kept.
  std::size_t separator_pos = std::string::npos;
  std::size_t separator_length = 0;
  for (const std::string_view token : separator_tokens) {
    const std::size_t candidate_pos = raw_response.rfind(token);
    if (candidate_pos != std::string::npos &&
        (separator_pos == std::string::npos ||
         candidate_pos > separator_pos)) {
      separator_pos = candidate_pos;
      separator_length = token.size();
    }
  }

  if (separator_pos != std::string::npos) {
    raw_response.erase(0, separator_pos + separator_length);
  }

  // `trimmed` views into raw_response, which is not modified past this point.
  const std::string_view trimmed = trim(raw_response);

  // Run the extraction once and reuse the result (the previous code invoked
  // the helper a second time on the same input just to return it).
  std::string json_candidate =
      ExtractLastJsonObjectPublic(std::string(trimmed));
  if (!json_candidate.empty()) {
    return json_candidate;
  }

  return std::string(trimmed);
}
BreweryResult LlamaGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const Location& location, const std::string& region_context) {
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
* Default path: prompts/brewery_system_prompt_expanded.txt
*/
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/brewery_system_prompt_expanded.txt");
LoadBrewerySystemPrompt("prompts/system.md");
/**
* User prompt: provides geographic context to guide generation towards
* culturally appropriate and locally-inspired brewery attributes
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt =
std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft "
"brewery in " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
"Location: " + city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name);
std::format("Location: {}{}", location.city, country_suffix);
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
const int max_attempts = 3;
constexpr int max_attempts = 3;
std::string raw;
std::string last_error;
// Limit output length to keep it concise and focused
constexpr int max_tokens = 1052;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = Infer(system_prompt, prompt, max_tokens);
raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
@@ -72,29 +116,29 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string name;
std::string description;
const std::string validation_error =
ValidateBreweryJsonPublic(raw, name, description);
if (validation_error.empty()) {
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return {std::move(name), std::move(description)};
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
}
// Validation failed: log error and prepare corrective feedback
last_error = validation_error;
last_error = *validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
attempt + 1, *validation_error);
// Update prompt with error details to guide LLM toward correct output.
// For retries, use a compact prompt format to avoid exceeding token
// limits.
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\n" +
retry_location;
prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
{})",
*validation_error, retry_location);
}
// All retry attempts exhausted: log failure and throw exception