diff --git a/pipeline/includes/biergarten_data_generator.h b/pipeline/includes/biergarten_data_generator.h index 9371031..a41575b 100644 --- a/pipeline/includes/biergarten_data_generator.h +++ b/pipeline/includes/biergarten_data_generator.h @@ -3,114 +3,151 @@ #include #include -#include #include +#include #include "data_generation/data_generator.h" #include "database/database.h" #include "web_client/web_client.h" #include "wikipedia/wikipedia_service.h" - /** * @brief Program options for the Biergarten pipeline application. */ struct ApplicationOptions { - /// @brief Path to the LLM model file (gguf format); mutually exclusive with use_mocked. - std::string model_path; + /// @brief Path to the LLM model file (gguf format); mutually exclusive with + /// use_mocked. + std::string model_path; - /// @brief Use mocked generator instead of LLM; mutually exclusive with model_path. - bool use_mocked = false; + /// @brief Use mocked generator instead of LLM; mutually exclusive with + /// model_path. + bool use_mocked = false; - /// @brief Directory for cached JSON and database files. - std::string cache_dir; + /// @brief Directory for cached JSON and database files. + std::string cache_dir; - /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). - float temperature = 0.8f; + /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). + float temperature = 0.8f; - /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more random). - float top_p = 0.92f; + /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more + /// random). + float top_p = 0.92f; - /// @brief Random seed for sampling (-1 for random, otherwise non-negative). - int seed = -1; + /// @brief Random seed for sampling (-1 for random, otherwise non-negative). + int seed = -1; - /// @brief Git commit hash for database consistency (always pinned to c5eb7772). - std::string commit = "c5eb7772"; + /// @brief Git commit hash for database consistency (always pinned to + /// c5eb7772). + std::string commit = "c5eb7772"; }; -#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_ - - /** * @brief Main data generator class for the Biergarten pipeline. * * This class encapsulates the core logic for generating brewery data. - * It handles database initialization, data loading/downloading, and brewery generation. + * It handles database initialization, data loading/downloading, and brewery + * generation. */ class BiergartenDataGenerator { -public: - /** - * @brief Construct a BiergartenDataGenerator with injected dependencies. - * - * @param options Application configuration options. - * @param web_client HTTP client for downloading data. - * @param database SQLite database instance. - */ - BiergartenDataGenerator(const ApplicationOptions &options, - std::shared_ptr web_client, - SqliteDatabase &database); + public: + /** + * @brief Construct a BiergartenDataGenerator with injected dependencies. + * + * @param options Application configuration options. + * @param web_client HTTP client for downloading data. + * @param database SQLite database instance. + */ + BiergartenDataGenerator(const ApplicationOptions& options, + std::shared_ptr web_client, + SqliteDatabase& database); - /** - * @brief Run the data generation pipeline. - * - * Performs the following steps: - * 1. Initialize database - * 2. Download geographic data if needed - * 3. Initialize the generator (LLM or Mock) - * 4. Generate brewery data for sample cities - * - * @return 0 on success, 1 on failure. - */ - int Run(); + /** + * @brief Run the data generation pipeline. + * + * Performs the following steps: + * 1. Initialize database + * 2. Download geographic data if needed + * 3. Initialize the generator (LLM or Mock) + * 4. Generate brewery data for sample cities + * + * @return 0 on success, 1 on failure. + */ + int Run(); -private: - /// @brief Immutable application options. - const ApplicationOptions options_; + private: + /// @brief Immutable application options. + const ApplicationOptions options_; - /// @brief Shared HTTP client dependency. - std::shared_ptr webClient_; + /// @brief Shared HTTP client dependency. + std::shared_ptr webClient_; - /// @brief Database dependency. - SqliteDatabase &database_; + /// @brief Database dependency. + SqliteDatabase& database_; - /** - * @brief Initialize the data generator based on options. - * - * Creates either a MockGenerator (if no model path) or LlamaGenerator. - * - * @return A unique_ptr to the initialized generator. - */ - std::unique_ptr InitializeGenerator(); + /** + * @brief Enriched city data with Wikipedia context. + */ + struct EnrichedCity { + int city_id; + std::string city_name; + std::string country_name; + std::string region_context; + }; - /** - * @brief Download and load geographic data if not cached. - */ - void LoadGeographicData(); + /** + * @brief Initialize the data generator based on options. + * + * Creates either a MockGenerator (if no model path) or LlamaGenerator. + * + * @return A unique_ptr to the initialized generator. + */ + std::unique_ptr InitializeGenerator(); - /** - * @brief Generate sample breweries for demonstration. - */ - void GenerateSampleBreweries(); + /** + * @brief Download and load geographic data if not cached. + */ + void LoadGeographicData(); - /** - * @brief Helper struct to store generated brewery data. - */ - struct GeneratedBrewery { - int city_id; - std::string city_name; - BreweryResult brewery; - }; + /** + * @brief Query cities from database and build country name map. + * + * @return Vector of (City, country_name) pairs capped at 30 entries. + */ + std::vector> QueryCitiesWithCountries(); - /// @brief Stores generated brewery data. - std::vector generatedBreweries_; + /** + * @brief Enrich cities with Wikipedia summaries. + * + * @param cities Vector of (City, country_name) pairs. + * @return Vector of enriched city data with context. + */ + std::vector EnrichWithWikipedia( + const std::vector>& cities); + + /** + * @brief Generate breweries for enriched cities. + * + * @param generator The data generator instance. + * @param cities Vector of enriched city data. + */ + void GenerateBreweries(DataGenerator& generator, + const std::vector& cities); + + /** + * @brief Log the generated brewery results. + */ + void LogResults() const; + + /** + * @brief Helper struct to store generated brewery data. + */ + struct GeneratedBrewery { + int city_id; + std::string city_name; + BreweryResult brewery; + }; + + /// @brief Stores generated brewery data. + std::vector generatedBreweries_; }; +#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_ diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index 31d29b1..a996574 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -31,6 +31,9 @@ class LlamaGenerator final : public DataGenerator { std::string Infer(const std::string& system_prompt, const std::string& prompt, int max_tokens = 10000); + std::string InferFormatted(const std::string& formatted_prompt, + int max_tokens = 10000); + llama_model* model_ = nullptr; llama_context* context_ = nullptr; float sampling_temperature_ = 0.8f; diff --git a/pipeline/src/biergarten_data_generator.cpp b/pipeline/src/biergarten_data_generator.cpp index c2d2389..11dc4d6 100644 --- a/pipeline/src/biergarten_data_generator.cpp +++ b/pipeline/src/biergarten_data_generator.cpp @@ -1,132 +1,157 @@ #include "biergarten_data_generator.h" +#include + #include #include #include -#include - #include "data_generation/data_downloader.h" -#include "json_handling/json_loader.h" #include "data_generation/llama_generator.h" #include "data_generation/mock_generator.h" +#include "json_handling/json_loader.h" #include "wikipedia/wikipedia_service.h" BiergartenDataGenerator::BiergartenDataGenerator( - const ApplicationOptions &options, - std::shared_ptr web_client, - SqliteDatabase &database) + const ApplicationOptions& options, std::shared_ptr web_client, + SqliteDatabase& database) : options_(options), webClient_(web_client), database_(database) {} std::unique_ptr BiergartenDataGenerator::InitializeGenerator() { - spdlog::info("Initializing brewery generator..."); + spdlog::info("Initializing brewery generator..."); - std::unique_ptr generator; - if (options_.model_path.empty()) { - generator = std::make_unique(); - spdlog::info("[Generator] Using MockGenerator (no model path provided)"); - } else { - auto llama_generator = std::make_unique(); - llama_generator->SetSamplingOptions(options_.temperature, options_.top_p, - options_.seed); - spdlog::info( - "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, " - "seed={})", - options_.model_path, options_.temperature, options_.top_p, - options_.seed); - generator = std::move(llama_generator); - } - generator->Load(options_.model_path); + std::unique_ptr generator; + if (options_.model_path.empty()) { + generator = std::make_unique(); + spdlog::info("[Generator] Using MockGenerator (no model path provided)"); + } else { + auto llama_generator = std::make_unique(); + llama_generator->SetSamplingOptions(options_.temperature, options_.top_p, + options_.seed); + spdlog::info( + "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, " + "seed={})", + options_.model_path, options_.temperature, options_.top_p, + options_.seed); + generator = std::move(llama_generator); + } + generator->Load(options_.model_path); - return generator; + return generator; } void BiergartenDataGenerator::LoadGeographicData() { - std::string json_path = options_.cache_dir + "/countries+states+cities.json"; - std::string db_path = options_.cache_dir + "/biergarten-pipeline.db"; + std::string json_path = options_.cache_dir + "/countries+states+cities.json"; + std::string db_path = options_.cache_dir + "/biergarten-pipeline.db"; - bool has_json_cache = std::filesystem::exists(json_path); - bool has_db_cache = std::filesystem::exists(db_path); + bool has_json_cache = std::filesystem::exists(json_path); + bool has_db_cache = std::filesystem::exists(db_path); - spdlog::info("Initializing SQLite database at {}...", db_path); - database_.Initialize(db_path); + spdlog::info("Initializing SQLite database at {}...", db_path); + database_.Initialize(db_path); - if (has_db_cache && has_json_cache) { - spdlog::info("[Pipeline] Cache hit: skipping download and parse"); - } else { - spdlog::info("\n[Pipeline] Downloading geographic data from GitHub..."); - DataDownloader downloader(webClient_); - downloader.DownloadCountriesDatabase(json_path, options_.commit); + if (has_db_cache && has_json_cache) { + spdlog::info("[Pipeline] Cache hit: skipping download and parse"); + } else { + spdlog::info("\n[Pipeline] Downloading geographic data from GitHub..."); + DataDownloader downloader(webClient_); + downloader.DownloadCountriesDatabase(json_path, options_.commit); - JsonLoader::LoadWorldCities(json_path, database_); - } + JsonLoader::LoadWorldCities(json_path, database_); + } } -void BiergartenDataGenerator::GenerateSampleBreweries() { - auto generator = InitializeGenerator(); - WikipediaService wikipedia_service(webClient_); +std::vector> +BiergartenDataGenerator::QueryCitiesWithCountries() { + spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); - spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); + auto cities = database_.QueryCities(); - auto countries = database_.QueryCountries(50); - auto states = database_.QueryStates(50); - auto cities = database_.QueryCities(); + // Build a quick map of country id -> name for per-city lookups. + auto all_countries = database_.QueryCountries(0); + std::unordered_map country_map; + for (const auto& c : all_countries) { + country_map[c.id] = c.name; + } - // Build a quick map of country id -> name for per-city lookups. - auto all_countries = database_.QueryCountries(0); - std::unordered_map country_map; - for (const auto &c : all_countries) - country_map[c.id] = c.name; + spdlog::info("\nTotal records loaded:"); + spdlog::info(" Countries: {}", database_.QueryCountries(0).size()); + spdlog::info(" States: {}", database_.QueryStates(0).size()); + spdlog::info(" Cities: {}", cities.size()); - spdlog::info("\nTotal records loaded:"); - spdlog::info(" Countries: {}", database_.QueryCountries(0).size()); - spdlog::info(" States: {}", database_.QueryStates(0).size()); - spdlog::info(" Cities: {}", cities.size()); + // Cap at 30 entries. + const size_t sample_count = std::min(size_t(30), cities.size()); + std::vector> result; - generatedBreweries_.clear(); - const size_t sample_count = std::min(size_t(30), cities.size()); + for (size_t i = 0; i < sample_count; i++) { + const auto& city = cities[i]; + std::string country_name; + const auto country_it = country_map.find(city.country_id); + if (country_it != country_map.end()) { + country_name = country_it->second; + } + result.push_back({city, country_name}); + } - spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); - for (size_t i = 0; i < sample_count; i++) { - const auto &city = cities[i]; - const int city_id = city.id; - const std::string city_name = city.name; + return result; +} - std::string local_country; - const auto country_it = country_map.find(city.country_id); - if (country_it != country_map.end()) { - local_country = country_it->second; - } +std::vector +BiergartenDataGenerator::EnrichWithWikipedia( + const std::vector>& cities) { + WikipediaService wikipedia_service(webClient_); + std::vector enriched; - const std::string region_context = - wikipedia_service.GetSummary(city_name, local_country); - spdlog::debug("[Pipeline] Region context for {}: {}", city_name, - region_context); + for (const auto& [city, country_name] : cities) { + const std::string region_context = + wikipedia_service.GetSummary(city.name, country_name); + spdlog::debug("[Pipeline] Region context for {}: {}", city.name, + region_context); - auto brewery = - generator->GenerateBrewery(city_name, local_country, region_context); - generatedBreweries_.push_back({city_id, city_name, brewery}); - } + enriched.push_back({city.id, city.name, country_name, region_context}); + } - spdlog::info("\n=== GENERATED DATA DUMP ==="); - for (size_t i = 0; i < generatedBreweries_.size(); i++) { - const auto &entry = generatedBreweries_[i]; - spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.city_id, - entry.city_name); - spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); - spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); - } + return enriched; +} + +void BiergartenDataGenerator::GenerateBreweries( + DataGenerator& generator, const std::vector& cities) { + spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); + generatedBreweries_.clear(); + + for (const auto& enriched_city : cities) { + auto brewery = generator.GenerateBrewery(enriched_city.city_name, + enriched_city.country_name, + enriched_city.region_context); + generatedBreweries_.push_back( + {enriched_city.city_id, enriched_city.city_name, brewery}); + } +} + +void BiergartenDataGenerator::LogResults() const { + spdlog::info("\n=== GENERATED DATA DUMP ==="); + for (size_t i = 0; i < generatedBreweries_.size(); i++) { + const auto& entry = generatedBreweries_[i]; + spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.city_id, + entry.city_name); + spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); + spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); + } } int BiergartenDataGenerator::Run() { - try { - LoadGeographicData(); - GenerateSampleBreweries(); + try { + LoadGeographicData(); + auto generator = InitializeGenerator(); + auto cities = QueryCitiesWithCountries(); + auto enriched = EnrichWithWikipedia(cities); + GenerateBreweries(*generator, enriched); + LogResults(); - spdlog::info("\nOK: Pipeline completed successfully"); - return 0; - } catch (const std::exception &e) { - spdlog::error("ERROR: Pipeline failed: {}", e.what()); - return 1; - } + spdlog::info("\nOK: Pipeline completed successfully"); + return 0; + } catch (const std::exception& e) { + spdlog::error("ERROR: Pipeline failed: {}", e.what()); + return 1; + } } diff --git a/pipeline/src/data_generation/llama/infer.cpp b/pipeline/src/data_generation/llama/infer.cpp index c938f87..ae1b786 100644 --- a/pipeline/src/data_generation/llama/infer.cpp +++ b/pipeline/src/data_generation/llama/infer.cpp @@ -11,100 +11,17 @@ #include "llama.h" std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { - if (model_ == nullptr || context_ == nullptr) - throw std::runtime_error("LlamaGenerator: model not loaded"); - - const llama_vocab* vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) - throw std::runtime_error("LlamaGenerator: vocab unavailable"); - - llama_memory_clear(llama_get_memory(context_), true); - - const std::string formatted_prompt = ToChatPromptPublic(model_, prompt); - - std::vector prompt_tokens(formatted_prompt.size() + 8); - int32_t token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); - - if (token_count < 0) { - prompt_tokens.resize(static_cast(-token_count)); - token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); - } - - if (token_count < 0) - throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); - - const int32_t n_ctx = static_cast(llama_n_ctx(context_)); - const int32_t n_batch = static_cast(llama_n_batch(context_)); - if (n_ctx <= 1 || n_batch <= 0) { - throw std::runtime_error("LlamaGenerator: invalid context or batch size"); - } - - const int32_t effective_max_tokens = - std::max(1, std::min(max_tokens, n_ctx - 1)); - int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); - prompt_budget = std::max(1, prompt_budget); - - prompt_tokens.resize(static_cast(token_count)); - if (token_count > prompt_budget) { - spdlog::warn( - "LlamaGenerator: prompt too long ({} tokens), truncating to {} " - "tokens " - "to fit n_batch/n_ctx limits", - token_count, prompt_budget); - prompt_tokens.resize(static_cast(prompt_budget)); - token_count = prompt_budget; - } - - const llama_batch prompt_batch = llama_batch_get_one( - prompt_tokens.data(), static_cast(prompt_tokens.size())); - if (llama_decode(context_, prompt_batch) != 0) - throw std::runtime_error("LlamaGenerator: prompt decode failed"); - - llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - using SamplerPtr = - std::unique_ptr; - SamplerPtr sampler(llama_sampler_chain_init(sampler_params), - &llama_sampler_free); - if (!sampler) - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_temp(sampling_temperature_)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_top_p(sampling_top_p_, 1)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_dist(sampling_seed_)); - - std::vector generated_tokens; - generated_tokens.reserve(static_cast(max_tokens)); - - for (int i = 0; i < effective_max_tokens; ++i) { - const llama_token next = - llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) break; - generated_tokens.push_back(next); - llama_token token = next; - const llama_batch one_token_batch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, one_token_batch) != 0) - throw std::runtime_error( - "LlamaGenerator: decode failed during generation"); - } - - std::string output; - for (const llama_token token : generated_tokens) - AppendTokenPiecePublic(vocab, token, output); - return output; + return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens); } std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, int max_tokens) { + return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt), + max_tokens); +} + +std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, + int max_tokens) { if (model_ == nullptr || context_ == nullptr) throw std::runtime_error("LlamaGenerator: model not loaded"); @@ -114,9 +31,6 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt, llama_memory_clear(llama_get_memory(context_), true); - const std::string formatted_prompt = - ToChatPromptPublic(model_, system_prompt, prompt); - std::vector prompt_tokens(formatted_prompt.size() + 8); int32_t token_count = llama_tokenize( vocab, formatted_prompt.c_str(), @@ -136,9 +50,8 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt, const int32_t n_ctx = static_cast(llama_n_ctx(context_)); const int32_t n_batch = static_cast(llama_n_batch(context_)); - if (n_ctx <= 1 || n_batch <= 0) { + if (n_ctx <= 1 || n_batch <= 0) throw std::runtime_error("LlamaGenerator: invalid context or batch size"); - } const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); @@ -149,8 +62,7 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt, if (token_count > prompt_budget) { spdlog::warn( "LlamaGenerator: prompt too long ({} tokens), truncating to {} " - "tokens " - "to fit n_batch/n_ctx limits", + "tokens to fit n_batch/n_ctx limits", token_count, prompt_budget); prompt_tokens.resize(static_cast(prompt_budget)); token_count = prompt_budget; @@ -178,7 +90,7 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt, llama_sampler_init_dist(sampling_seed_)); std::vector generated_tokens; - generated_tokens.reserve(static_cast(max_tokens)); + generated_tokens.reserve(static_cast(effective_max_tokens)); for (int i = 0; i < effective_max_tokens; ++i) { const llama_token next =