Add local language handling

This commit is contained in:
Aaron Po
2026-04-18 00:43:05 -04:00
parent f782fdb51d
commit 9649c993e8
11 changed files with 300 additions and 709 deletions

View File

@@ -13,7 +13,7 @@
#include "biergarten_data_generator.h"
#include "json_handling/json_loader.h"
static constexpr size_t kBreweryAmount = 5;
static constexpr size_t kBreweryAmount = 50;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");

View File

@@ -19,7 +19,7 @@ bool BiergartenDataGenerator::Run() {
for (auto& city : cities) {
try {
std::string region_context = context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
enriched.push_back(

View File

@@ -4,18 +4,35 @@
* inference, and validates structured JSON output for brewery records.
*/
#include "data_generation/llama_generator.h"
#include <spdlog/spdlog.h>
#include <format>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
static std::string FormatLocalLanguageCodes(
const std::vector<std::string>& codes) {
if (codes.empty()) {
return "Not provided";
}
std::string formatted;
for (const std::string& code : codes) {
if (!formatted.empty()) {
formatted += ", ";
}
formatted += code;
}
return formatted;
}
static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery(
root ::= thought-block "{" ws "\"name_en\"" ws ":" ws string ws "," ws "\"description_en\"" ws ":" ws string ws "," ws "\"name_local\"" ws ":" ws string ws "," ws "\"description_local\"" ws ":" ws string ws "}" ws
thought-block ::= [^{]*
@@ -35,8 +52,10 @@ BreweryResult LlamaGenerator::GenerateBrewery(
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContext(region_context);
const std::string safe_region_context = PrepareRegionContext(region_context);
const std::string local_language_codes =
FormatLocalLanguageCodes(location.local_languages);
const std::string country_suffix =
location.country.empty() ? std::string{}
@@ -48,16 +67,18 @@ BreweryResult LlamaGenerator::GenerateBrewery(
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md");
std::string user_prompt = std::format(
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}",
location.city, location.country, safe_region_context);
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## LOCAL LANGUAGE CODES:\n{}\n\n## "
"CONTEXT:\n{}",
location.city, location.country, local_language_codes,
safe_region_context);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
std::format("Location: {}{}", location.city, country_suffix);
std::format("Location: {}{}\nLocal language codes: {}", location.city,
country_suffix, local_language_codes);
/**
* RETRY LOOP with validation and error correction
@@ -68,16 +89,17 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string raw;
std::string last_error;
// Token budget: too small risks truncating valid JSON mid-string.
// Start conservatively but allow adaptive increases on truncation.
int max_tokens = kBreweryInitialMaxTokens;
// Token budget: too small risks truncating valid JSON mid-string.
// Start conservatively but allow adaptive increases on truncation.
int max_tokens = kBreweryInitialMaxTokens;
// Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) {
// Generate brewery data from LLM
raw = this->Infer(system_prompt, user_prompt, max_tokens, kBreweryJsonGrammar);
raw = this->Infer(system_prompt, user_prompt, max_tokens,
kBreweryJsonGrammar);
spdlog::info("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
raw);
// Validate output: parse JSON and check required fields
@@ -89,9 +111,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
// Success: return parsed brewery data
spdlog::info(
"LlamaGenerator: successfully generated brewery data on attempt {}:\n name_en='{}',\n description_en='{}',\n name_local='{}',\n description_local='{}'",
attempt + 1, brewery.name_en, brewery.description_en,
brewery.name_local, brewery.description_local);
"LlamaGenerator: successfully generated brewery data on attempt {}",
attempt + 1);
return brewery;
}
@@ -102,13 +123,13 @@ BreweryResult LlamaGenerator::GenerateBrewery(
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error);
if (last_error == "JSON parse error: incomplete JSON") {
const int previous_max_tokens = max_tokens;
max_tokens = std::min(max_tokens + kBreweryTruncationRetryTokenBump,
kBreweryMaxTokensCeiling);
if (last_error == "JSON parse error: incomplete JSON") {
const int previous_max_tokens = max_tokens;
max_tokens = std::min(max_tokens + kBreweryTruncationRetryTokenBump,
kBreweryMaxTokensCeiling);
spdlog::info(
"LlamaGenerator: detected truncated JSON; increasing max_tokens from {} to {} and retrying",
"LlamaGenerator: detected truncated JSON; increasing max_tokens from "
"{} to {} and retrying",
previous_max_tokens, max_tokens);
continue;
@@ -116,13 +137,15 @@ BreweryResult LlamaGenerator::GenerateBrewery(
// Update prompt with error details to guide LLM toward correct output.
user_prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return the thought process before the JSON if needed, then return ONLY valid JSON with exactly these keys, in this exact order: {{"name_en": "<English brewery name>", "description_en": "<English single-paragraph description>", "name_local": "<local-language brewery name>", "description_local": "<local-language single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
Keep the JSON strings concise enough to fit within the token budget.
{})",
"Your previous response was invalid. Error: {}\nReturn the thought "
"process before the JSON if needed, then return ONLY valid JSON with "
"exactly these keys, in this exact order: {{\"name_en\": \"<English "
"brewery name>\", \"description_en\": \"<English single-paragraph "
"description>\", \"name_local\": \"<local-language brewery name>\", "
"\"description_local\": \"<local-language single-paragraph "
"description>\"}}.\nDo not include markdown, comments, extra keys, or "
"literal placeholder values.\n\nKeep the JSON strings concise enough "
"to fit within the token budget.\n\n{}",
*validation_error, retry_location);
}

View File

@@ -32,7 +32,7 @@ void LlamaGenerator::ContextDeleter::operator()(
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path,
std::shared_ptr<IPromptFormatter> prompt_formatter)
std::unique_ptr<IPromptFormatter> prompt_formatter)
: rng_(std::random_device{}()),
prompt_formatter_(std::move(prompt_formatter)) {
if (model_path.empty()) {

View File

@@ -35,6 +35,27 @@ static double ReadRequiredNumber(const boost::json::object& object,
return value->to_number<double>();
}
static std::vector<std::string> ReadRequiredStringArray(
const boost::json::object& object, const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_array()) {
throw std::runtime_error(std::string("Missing or invalid string array field: ") +
key);
}
const auto& array = value->as_array();
std::vector<std::string> items;
items.reserve(array.size());
for (const auto& item : array) {
if (!item.is_string()) {
throw std::runtime_error(std::string("Missing or invalid string array field: ") +
key);
}
items.emplace_back(item.as_string());
}
return items;
}
std::vector<Location> JsonLoader::LoadLocations(
const std::filesystem::path& filepath) {
std::ifstream input(filepath);
@@ -76,6 +97,8 @@ std::vector<Location> JsonLoader::LoadLocations(
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.local_languages =
ReadRequiredStringArray(object, "local_languages"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});