From a057b9197f546590ccd977df37c1975e806c73ab Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Wed, 13 May 2026 22:04:48 -0400 Subject: [PATCH] Add location count to application options and as a cli arg --- tooling/pipeline/includes/biergarten_data_generator.h | 9 ++++++--- tooling/pipeline/includes/data_model/models.h | 4 ++++ .../pipeline/src/application_options/parse_arguments.cc | 6 ++++-- .../biergarten_data_generator.cc | 6 ++++-- .../query_cities_with_countries.cc | 6 +++--- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/tooling/pipeline/includes/biergarten_data_generator.h b/tooling/pipeline/includes/biergarten_data_generator.h index 1396213..e74ba02 100644 --- a/tooling/pipeline/includes/biergarten_data_generator.h +++ b/tooling/pipeline/includes/biergarten_data_generator.h @@ -12,8 +12,8 @@ #include "data_generation/data_generator.h" #include "data_model/generated_models.h" -#include "services/enrichment/enrichment_service.h" #include "services/database/export_service.h" +#include "services/enrichment/enrichment_service.h" /** * @brief Main data generator class for the Biergarten pipeline. @@ -32,7 +32,8 @@ class BiergartenDataGenerator { */ BiergartenDataGenerator(std::unique_ptr context_service, std::unique_ptr generator, - std::unique_ptr exporter); + std::unique_ptr exporter, + const ApplicationOptions& application_options); /** * @brief Run the data generation pipeline. @@ -56,12 +57,14 @@ class BiergartenDataGenerator { /// @brief Storage backend for generated brewery records. std::unique_ptr exporter_; + const ApplicationOptions application_options_; + /** * @brief Load locations from JSON and sample cities. * * @return Vector of sampled locations capped at 50 entries. */ - static std::vector QueryCitiesWithCountries(); + std::vector QueryCitiesWithCountries(); /** * @brief Generate breweries for enriched cities. 
diff --git a/tooling/pipeline/includes/data_model/models.h b/tooling/pipeline/includes/data_model/models.h index c046557..9346b01 100644 --- a/tooling/pipeline/includes/data_model/models.h +++ b/tooling/pipeline/includes/data_model/models.h @@ -118,6 +118,10 @@ struct PipelineOptions { /// @brief Path for application logs. std::filesystem::path log_path; + + /// @brief Number of locations to sample from the dataset. + /// Sampling more locations yields more users and more breweries. + uint32_t location_count; }; /** diff --git a/tooling/pipeline/src/application_options/parse_arguments.cc b/tooling/pipeline/src/application_options/parse_arguments.cc index e568bd9..b2995d1 100644 --- a/tooling/pipeline/src/application_options/parse_arguments.cc +++ b/tooling/pipeline/src/application_options/parse_arguments.cc @@ -31,7 +31,7 @@ std::optional ParseArguments(const int argc, char** argv) { opt("seed", prog_opts::value()->default_value(sampling_defaults.seed), "Sampler seed: -1 for random, otherwise non-negative integer"); opt("n-gpu-layers", prog_opts::value()->default_value(0), - "Number of layers to offload to GPU"); + "Number of layers to offload to GPU"); }; // --mocked and --model are mutually exclusive; validation is enforced below @@ -52,7 +52,7 @@ std::optional ParseArguments(const int argc, char** argv) { opt("prompt-dir", prog_opts::value()->default_value(""), "Directory containing named prompt files (e.g. BREWERY_GENERATION.md)." 
" Required when not using --mocked."); - + opt("location-count", prog_opts::value()->default_value(10)); }; add_sampling_options(); @@ -85,6 +85,8 @@ std::optional ParseArguments(const int argc, char** argv) { options.pipeline.output_path = var_map["output"].as(); options.pipeline.log_path = var_map["log-path"].as(); options.pipeline.prompt_dir = var_map["prompt-dir"].as(); + options.pipeline.location_count = + var_map["location-count"].as(); const bool use_mocked = var_map["mocked"].as(); const std::string model_path = var_map["model"].as(); diff --git a/tooling/pipeline/src/biergarten_data_generator/biergarten_data_generator.cc b/tooling/pipeline/src/biergarten_data_generator/biergarten_data_generator.cc index 033795d..71875b3 100644 --- a/tooling/pipeline/src/biergarten_data_generator/biergarten_data_generator.cc +++ b/tooling/pipeline/src/biergarten_data_generator/biergarten_data_generator.cc @@ -10,7 +10,9 @@ BiergartenDataGenerator::BiergartenDataGenerator( std::unique_ptr context_service, std::unique_ptr generator, - std::unique_ptr exporter) + std::unique_ptr exporter, + const ApplicationOptions &app_options) : context_service_(std::move(context_service)), generator_(std::move(generator)), - exporter_(std::move(exporter)) {} + exporter_(std::move(exporter)), + application_options_(app_options) {} diff --git a/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc b/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc index 2427a15..c17654f 100644 --- a/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc +++ b/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc @@ -13,8 +13,6 @@ #include "biergarten_data_generator.h" #include "json_handling/json_loader.h" -static constexpr size_t kBreweryAmount = 40; - std::vector BiergartenDataGenerator::QueryCitiesWithCountries() { spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); @@ -23,7 +21,9 @@ std::vector 
BiergartenDataGenerator::QueryCitiesWithCountries() { auto all_locations = JsonLoader::LoadLocations(locations_path); spdlog::info(" Locations available: {}", all_locations.size()); - const size_t sample_count = std::min(kBreweryAmount, all_locations.size()); + const size_t sample_count = std::min( + static_cast(application_options_.pipeline.location_count), + all_locations.size()); const auto sample_count_signed = static_cast>(