diff --git a/pipeline/includes/biergarten_data_generator.h b/pipeline/includes/biergarten_data_generator.h index a41575b..1fbe00e 100644 --- a/pipeline/includes/biergarten_data_generator.h +++ b/pipeline/includes/biergarten_data_generator.h @@ -33,6 +33,10 @@ struct ApplicationOptions { /// random). float top_p = 0.92f; + /// @brief Context window size (tokens) for LLM inference. Higher values + /// support longer prompts but use more memory. + uint32_t n_ctx = 2048; + /// @brief Random seed for sampling (-1 for random, otherwise non-negative). int seed = -1; diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index a996574..92d7d52 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -16,6 +16,8 @@ class LlamaGenerator final : public DataGenerator { void SetSamplingOptions(float temperature, float top_p, int seed = -1); + void SetContextSize(uint32_t n_ctx); + void Load(const std::string& model_path) override; BreweryResult GenerateBrewery(const std::string& city_name, const std::string& country_name, @@ -39,6 +41,7 @@ class LlamaGenerator final : public DataGenerator { float sampling_temperature_ = 0.8f; float sampling_top_p_ = 0.92f; uint32_t sampling_seed_ = 0xFFFFFFFFu; + uint32_t n_ctx_ = 2048; }; #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_ diff --git a/pipeline/includes/database/database.h b/pipeline/includes/database/database.h index 03307fe..01b2f23 100644 --- a/pipeline/includes/database/database.h +++ b/pipeline/includes/database/database.h @@ -59,6 +59,9 @@ class SqliteDatabase { /// @brief Commits the active database transaction. void CommitTransaction(); + /// @brief Rolls back the active database transaction. + void RollbackTransaction(); + /// @brief Inserts a country row. void InsertCountry(int id, const std::string& name, const std::string& iso2, const std::string& iso3); diff --git a/pipeline/src/biergarten_data_generator.cpp b/pipeline/src/biergarten_data_generator.cpp index 11dc4d6..6663b5e 100644 --- a/pipeline/src/biergarten_data_generator.cpp +++ b/pipeline/src/biergarten_data_generator.cpp @@ -28,11 +28,12 @@ std::unique_ptr BiergartenDataGenerator::InitializeGenerator() { auto llama_generator = std::make_unique(); llama_generator->SetSamplingOptions(options_.temperature, options_.top_p, options_.seed); + llama_generator->SetContextSize(options_.n_ctx); spdlog::info( "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, " - "seed={})", + "n_ctx={}, seed={})", options_.model_path, options_.temperature, options_.top_p, - options_.seed); + options_.n_ctx, options_.seed); generator = std::move(llama_generator); } generator->Load(options_.model_path); diff --git a/pipeline/src/data_generation/data_downloader.cpp b/pipeline/src/data_generation/data_downloader.cpp index 802d90c..83861d4 100644 --- a/pipeline/src/data_generation/data_downloader.cpp +++ b/pipeline/src/data_generation/data_downloader.cpp @@ -25,15 +25,10 @@ std::string DataDownloader::DownloadCountriesDatabase( return cache_path; } - std::string short_commit = commit; - if (commit.length() > 7) { - short_commit = commit.substr(0, 7); - } - std::string url = "https://raw.githubusercontent.com/dr5hn/" "countries-states-cities-database/" + - short_commit + "/json/countries+states+cities.json"; + commit + "/json/countries+states+cities.json"; spdlog::info("[DataDownloader] Downloading: {}", url); diff --git a/pipeline/src/data_generation/llama/generate_brewery.cpp b/pipeline/src/data_generation/llama/generate_brewery.cpp index 31ea071..86f1c13 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cpp +++ b/pipeline/src/data_generation/llama/generate_brewery.cpp @@ -50,6 +50,14 @@ BreweryResult LlamaGenerator::GenerateBrewery( ? std::string(".") : std::string(". Regional context: ") + safe_region_context); + /** + * Store location context for retry prompts (without repeating full context) + */ + const std::string retry_location = + "Location: " + city_name + + (country_name.empty() ? std::string("") + : std::string(", ") + country_name); + /** * RETRY LOOP with validation and error correction * Attempts to generate valid brewery data up to 3 times, with feedback-based @@ -84,19 +92,16 @@ BreweryResult LlamaGenerator::GenerateBrewery( spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", attempt + 1, validation_error); - // Update prompt with error details to guide LLM toward correct output + // Update prompt with error details to guide LLM toward correct output. + // For retries, use a compact prompt format to avoid exceeding token + // limits. prompt = "Your previous response was invalid. Error: " + validation_error + "\nReturn ONLY valid JSON with this exact schema: " "{\"name\": \"string\", \"description\": \"string\"}." "\nDo not include markdown, comments, or extra keys." - "\n\nLocation: " + - city_name + - (country_name.empty() ? std::string("") - : std::string(", ") + country_name) + - (safe_region_context.empty() - ? std::string("") - : std::string("\nRegional context: ") + safe_region_context); + "\n\n" + + retry_location; } // All retry attempts exhausted: log failure and throw exception diff --git a/pipeline/src/data_generation/llama/helpers.cpp b/pipeline/src/data_generation/llama/helpers.cpp index 2a6bd2d..2bac497 100644 --- a/pipeline/src/data_generation/llama/helpers.cpp +++ b/pipeline/src/data_generation/llama/helpers.cpp @@ -147,7 +147,17 @@ std::pair ParseTwoLineResponse( std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - if (!l.empty() && l.front() == '<' && low.back() == '>') continue; + // Filter known thinking tags like ..., but be conservative + // to avoid removing legitimate output. Only filter specific known + // patterns. + if (!l.empty() && l.front() == '<' && low.back() == '>') { + // Only filter if it's a known thinking tag: , , etc. + if (low.find("think") != std::string::npos || + low.find("reasoning") != std::string::npos || + low.find("reflect") != std::string::npos) { + continue; + } + } if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue; filtered.push_back(std::move(l)); } diff --git a/pipeline/src/data_generation/llama/infer.cpp b/pipeline/src/data_generation/llama/infer.cpp index 1a1c7d0..6036f90 100644 --- a/pipeline/src/data_generation/llama/infer.cpp +++ b/pipeline/src/data_generation/llama/infer.cpp @@ -186,5 +186,11 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, std::string output; for (const llama_token token : generated_tokens) AppendTokenPiecePublic(vocab, token, output); + + /** + * Advance seed for next generation to improve output diversity + */ + sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1; + return output; } diff --git a/pipeline/src/data_generation/llama/load.cpp b/pipeline/src/data_generation/llama/load.cpp index a20827d..52fa5c0 100644 --- a/pipeline/src/data_generation/llama/load.cpp +++ b/pipeline/src/data_generation/llama/load.cpp @@ -42,7 +42,7 @@ void LlamaGenerator::Load(const std::string& model_path) { } llama_context_params context_params = llama_context_default_params(); - context_params.n_ctx = 2048; + context_params.n_ctx = n_ctx_; context_ = llama_init_from_model(model_, context_params); if (context_ == nullptr) { diff --git a/pipeline/src/data_generation/llama/set_sampling_options.cpp b/pipeline/src/data_generation/llama/set_sampling_options.cpp index 7b9238c..c678863 100644 --- a/pipeline/src/data_generation/llama/set_sampling_options.cpp +++ b/pipeline/src/data_generation/llama/set_sampling_options.cpp @@ -48,3 +48,18 @@ void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED) : static_cast(seed); } + +void LlamaGenerator::SetContextSize(uint32_t n_ctx) { + /** + * Validate context size: must be positive and reasonable for the model + */ + if (n_ctx == 0 || n_ctx > 32768) { + throw std::runtime_error( + "LlamaGenerator: context size must be in range [1, 32768]"); + } + + /** + * Store context size for use during model loading + */ + n_ctx_ = n_ctx; +} diff --git a/pipeline/src/database/database.cpp b/pipeline/src/database/database.cpp index 6242a22..7d22bf3 100644 --- a/pipeline/src/database/database.cpp +++ b/pipeline/src/database/database.cpp @@ -80,6 +80,16 @@ void SqliteDatabase::CommitTransaction() { } } +void SqliteDatabase::RollbackTransaction() { + std::lock_guard lock(db_mutex_); + char* err = nullptr; + if (sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, &err) != SQLITE_OK) { + std::string msg = err ? err : "unknown"; + sqlite3_free(err); + throw std::runtime_error("RollbackTransaction failed: " + msg); + } +} + void SqliteDatabase::InsertCountry(int id, const std::string& name, const std::string& iso2, const std::string& iso3) { @@ -96,9 +106,9 @@ void SqliteDatabase::InsertCountry(int id, const std::string& name, throw std::runtime_error("Failed to prepare country insert"); sqlite3_bind_int(stmt, 1, id); - sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_TRANSIENT); if (sqlite3_step(stmt) != SQLITE_DONE) { throw std::runtime_error("Failed to insert country"); @@ -123,8 +133,8 @@ void SqliteDatabase::InsertState(int id, int country_id, sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 2, country_id); - sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_TRANSIENT); if (sqlite3_step(stmt) != SQLITE_DONE) { throw std::runtime_error("Failed to insert state"); @@ -150,7 +160,7 @@ void SqliteDatabase::InsertCity(int id, int state_id, int country_id, sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 2, state_id); sqlite3_bind_int(stmt, 3, country_id); - sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_TRANSIENT); sqlite3_bind_double(stmt, 5, latitude); sqlite3_bind_double(stmt, 6, longitude); @@ -165,7 +175,8 @@ std::vector SqliteDatabase::QueryCities() { std::vector cities; sqlite3_stmt* stmt = nullptr; - const char* query = "SELECT id, name, country_id FROM cities ORDER BY name"; + const char* query = + "SELECT id, name, country_id FROM cities ORDER BY RANDOM()"; int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); if (rc != SQLITE_OK) { diff --git a/pipeline/src/json_handling/json_loader.cpp b/pipeline/src/json_handling/json_loader.cpp index 71875a1..c535358 100644 --- a/pipeline/src/json_handling/json_loader.cpp +++ b/pipeline/src/json_handling/json_loader.cpp @@ -11,7 +11,7 @@ void JsonLoader::LoadWorldCities(const std::string& json_path, constexpr size_t kBatchSize = 10000; auto startTime = std::chrono::high_resolution_clock::now(); - spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path); + spdlog::info("\nLoading {} (streaming Boost.JSON SAX)...", json_path); db.BeginTransaction(); bool transactionOpen = true; @@ -44,7 +44,8 @@ void JsonLoader::LoadWorldCities(const std::string& json_path, } } catch (...) { if (transactionOpen) { - db.CommitTransaction(); + db.RollbackTransaction(); + transactionOpen = false; } throw; } diff --git a/pipeline/src/main.cpp b/pipeline/src/main.cpp index cc05334..0f3a4a8 100644 --- a/pipeline/src/main.cpp +++ b/pipeline/src/main.cpp @@ -1,12 +1,12 @@ +#include + +#include #include #include -#include -#include - #include "biergarten_data_generator.h" -#include "web_client/curl_web_client.h" #include "database/database.h" +#include "web_client/curl_web_client.h" namespace po = boost::program_options; @@ -18,101 +18,117 @@ namespace po = boost::program_options; * @param options Output ApplicationOptions struct. * @return true if parsing succeeded and should proceed, false otherwise. */ -bool ParseArguments(int argc, char **argv, ApplicationOptions &options) { - // If no arguments provided, display usage and exit - if (argc == 1) { - std::cout << "Biergarten Pipeline - Geographic Data Pipeline with Brewery Generation\n\n"; - std::cout << "Usage: biergarten-pipeline [options]\n\n"; - std::cout << "Options:\n"; - std::cout << " --mocked Use mocked generator for brewery/user data\n"; - std::cout << " --model, -m PATH Path to LLM model file (gguf) for generation\n"; - std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: /tmp)\n"; - std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 (default: 0.8)\n"; - std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 (default: 0.92)\n"; - std::cout << " --seed SEED Random seed: -1 for random (default: -1)\n"; - std::cout << " --help, -h Show this help message\n\n"; - std::cout << "Note: --mocked and --model are mutually exclusive. Exactly one must be provided.\n"; - std::cout << "Data source is always pinned to commit c5eb7772 (stable 2026-03-28).\n"; - return false; - } +bool ParseArguments(int argc, char** argv, ApplicationOptions& options) { + // If no arguments provided, display usage and exit + if (argc == 1) { + std::cout << "Biergarten Pipeline - Geographic Data Pipeline with " + "Brewery Generation\n\n"; + std::cout << "Usage: biergarten-pipeline [options]\n\n"; + std::cout << "Options:\n"; + std::cout << " --mocked Use mocked generator for " + "brewery/user data\n"; + std::cout << " --model, -m PATH Path to LLM model file (gguf) for " + "generation\n"; + std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: " + "/tmp)\n"; + std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 " + "(default: 0.8)\n"; + std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 " + "(default: 0.92)\n"; + std::cout << " --n-ctx SIZE Context window size in tokens " + "(default: 2048)\n"; + std::cout << " --seed SEED Random seed: -1 for random " + "(default: -1)\n"; + std::cout << " --help, -h Show this help message\n\n"; + std::cout << "Note: --mocked and --model are mutually exclusive. Exactly " + "one must be provided.\n"; + std::cout << "Data source is always pinned to commit c5eb7772 (stable " + "2026-03-28).\n"; + return false; + } - po::options_description desc("Pipeline Options"); - desc.add_options()("help,h", "Produce help message")( - "mocked", po::bool_switch(), - "Use mocked generator for brewery/user data")( - "model,m", po::value()->default_value(""), - "Path to LLM model (gguf)")( - "cache-dir,c", po::value()->default_value("/tmp"), - "Directory for cached JSON")( - "temperature", po::value()->default_value(0.8f), - "Sampling temperature (higher = more random)")( - "top-p", po::value()->default_value(0.92f), - "Nucleus sampling top-p in (0,1] (higher = more random)")( - "seed", po::value()->default_value(-1), - "Sampler seed: -1 for random, otherwise non-negative integer"); + po::options_description desc("Pipeline Options"); + desc.add_options()("help,h", "Produce help message")( + "mocked", po::bool_switch(), + "Use mocked generator for brewery/user data")( + "model,m", po::value()->default_value(""), + "Path to LLM model (gguf)")( + "cache-dir,c", po::value()->default_value("/tmp"), + "Directory for cached JSON")( + "temperature", po::value()->default_value(0.8f), + "Sampling temperature (higher = more random)")( + "top-p", po::value()->default_value(0.92f), + "Nucleus sampling top-p in (0,1] (higher = more random)")( + "n-ctx", po::value()->default_value(2048), + "Context window size in tokens (1-32768)")( + "seed", po::value()->default_value(-1), + "Sampler seed: -1 for random, otherwise non-negative integer"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - po::notify(vm); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); - if (vm.count("help")) { - std::cout << desc << "\n"; - return false; - } + if (vm.count("help")) { + std::cout << desc << "\n"; + return false; + } - // Check for mutually exclusive --mocked and --model flags - bool use_mocked = vm["mocked"].as(); - std::string model_path = vm["model"].as(); + // Check for mutually exclusive --mocked and --model flags + bool use_mocked = vm["mocked"].as(); + std::string model_path = vm["model"].as(); - if (use_mocked && !model_path.empty()) { - spdlog::error("ERROR: --mocked and --model are mutually exclusive"); - return false; - } + if (use_mocked && !model_path.empty()) { + spdlog::error("ERROR: --mocked and --model are mutually exclusive"); + return false; + } - if (!use_mocked && model_path.empty()) { - spdlog::error("ERROR: Either --mocked or --model must be specified"); - return false; - } + if (!use_mocked && model_path.empty()) { + spdlog::error("ERROR: Either --mocked or --model must be specified"); + return false; + } - // Warn if sampling parameters are provided with --mocked - if (use_mocked) { - bool hasTemperature = vm["temperature"].defaulted() == false; - bool hasTopP = vm["top-p"].defaulted() == false; - bool hasSeed = vm["seed"].defaulted() == false; + // Warn if sampling parameters are provided with --mocked + if (use_mocked) { + bool hasTemperature = vm["temperature"].defaulted() == false; + bool hasTopP = vm["top-p"].defaulted() == false; + bool hasSeed = vm["seed"].defaulted() == false; - if (hasTemperature || hasTopP || hasSeed) { - spdlog::warn("WARNING: Sampling parameters (--temperature, --top-p, --seed) are ignored when using --mocked"); - } - } + if (hasTemperature || hasTopP || hasSeed) { + spdlog::warn( + "WARNING: Sampling parameters (--temperature, --top-p, --seed) " + "are ignored when using --mocked"); + } + } - options.use_mocked = use_mocked; - options.model_path = model_path; - options.cache_dir = vm["cache-dir"].as(); - options.temperature = vm["temperature"].as(); - options.top_p = vm["top-p"].as(); - options.seed = vm["seed"].as(); - // commit is always pinned to c5eb7772 + options.use_mocked = use_mocked; + options.model_path = model_path; + options.cache_dir = vm["cache-dir"].as(); + options.temperature = vm["temperature"].as(); + options.top_p = vm["top-p"].as(); + options.n_ctx = vm["n-ctx"].as(); + options.seed = vm["seed"].as(); + // commit is always pinned to c5eb7772 - return true; + return true; } -int main(int argc, char *argv[]) { - try { - const CurlGlobalState curl_state; +int main(int argc, char* argv[]) { + try { + const CurlGlobalState curl_state; - ApplicationOptions options; - if (!ParseArguments(argc, argv, options)) { - return 0; - } + ApplicationOptions options; + if (!ParseArguments(argc, argv, options)) { + return 0; + } - auto webClient = std::make_shared(); - SqliteDatabase database; + auto webClient = std::make_shared(); + SqliteDatabase database; - BiergartenDataGenerator generator(options, webClient, database); - return generator.Run(); + BiergartenDataGenerator generator(options, webClient, database); + return generator.Run(); - } catch (const std::exception &e) { - spdlog::error("ERROR: Application failed: {}", e.what()); - return 1; - } + } catch (const std::exception& e) { + spdlog::error("ERROR: Application failed: {}", e.what()); + return 1; + } } diff --git a/pipeline/src/wikipedia/wikipedia_service.cpp b/pipeline/src/wikipedia/wikipedia_service.cpp index ef851bb..60b9743 100644 --- a/pipeline/src/wikipedia/wikipedia_service.cpp +++ b/pipeline/src/wikipedia/wikipedia_service.cpp @@ -11,7 +11,7 @@ std::string WikipediaService::FetchExtract(std::string_view query) { const std::string encoded = client_->UrlEncode(std::string(query)); const std::string url = "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + - "&prop=extracts&explaintext=true&format=json"; + "&prop=extracts&explaintext=1&format=json"; const std::string body = client_->Get(url); @@ -19,16 +19,27 @@ std::string WikipediaService::FetchExtract(std::string_view query) { boost::json::value doc = boost::json::parse(body, ec); if (!ec && doc.is_object()) { - auto& pages = doc.at("query").at("pages").get_object(); - if (!pages.empty()) { - auto& page = pages.begin()->value().get_object(); - if (page.contains("extract") && page.at("extract").is_string()) { - std::string extract(page.at("extract").as_string().c_str()); - spdlog::debug("WikipediaService fetched {} chars for '{}'", - extract.size(), query); - return extract; + try { + auto& pages = doc.at("query").at("pages").get_object(); + if (!pages.empty()) { + auto& page = pages.begin()->value().get_object(); + if (page.contains("extract") && page.at("extract").is_string()) { + std::string extract(page.at("extract").as_string().c_str()); + spdlog::debug("WikipediaService fetched {} chars for '{}'", + extract.size(), query); + return extract; + } } + } catch (const std::exception& e) { + spdlog::warn( + "WikipediaService: failed to parse response structure for '{}': " + "{}", + query, e.what()); + return {}; } + } else if (ec) { + spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query, + ec.message()); } return {};