mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Add location count to application options and as a cli arg
This commit is contained in:
@@ -12,8 +12,8 @@
|
|||||||
|
|
||||||
#include "data_generation/data_generator.h"
|
#include "data_generation/data_generator.h"
|
||||||
#include "data_model/generated_models.h"
|
#include "data_model/generated_models.h"
|
||||||
#include "services/enrichment/enrichment_service.h"
|
|
||||||
#include "services/database/export_service.h"
|
#include "services/database/export_service.h"
|
||||||
|
#include "services/enrichment/enrichment_service.h"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Main data generator class for the Biergarten pipeline.
|
* @brief Main data generator class for the Biergarten pipeline.
|
||||||
@@ -32,7 +32,8 @@ class BiergartenDataGenerator {
|
|||||||
*/
|
*/
|
||||||
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service,
|
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service,
|
||||||
std::unique_ptr<DataGenerator> generator,
|
std::unique_ptr<DataGenerator> generator,
|
||||||
std::unique_ptr<IExportService> exporter);
|
std::unique_ptr<IExportService> exporter,
|
||||||
|
const ApplicationOptions& application_options);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Run the data generation pipeline.
|
* @brief Run the data generation pipeline.
|
||||||
@@ -56,12 +57,14 @@ class BiergartenDataGenerator {
|
|||||||
/// @brief Storage backend for generated brewery records.
|
/// @brief Storage backend for generated brewery records.
|
||||||
std::unique_ptr<IExportService> exporter_;
|
std::unique_ptr<IExportService> exporter_;
|
||||||
|
|
||||||
|
const ApplicationOptions application_options_;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Load locations from JSON and sample cities.
|
* @brief Load locations from JSON and sample cities.
|
||||||
*
|
*
|
||||||
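* @return Vector of sampled locations, capped at the sample limit or the number of available locations, whichever is smaller.
|
* @return Vector of sampled locations, capped at the sample limit or the number of available locations, whichever is smaller.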
|
||||||
*/
|
*/
|
||||||
static std::vector<Location> QueryCitiesWithCountries();
|
std::vector<Location> QueryCitiesWithCountries();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Generate breweries for enriched cities.
|
* @brief Generate breweries for enriched cities.
|
||||||
|
|||||||
@@ -118,6 +118,10 @@ struct PipelineOptions {
|
|||||||
|
|
||||||
/// @brief Path for application logs.
|
/// @brief Path for application logs.
|
||||||
std::filesystem::path log_path;
|
std::filesystem::path log_path;
|
||||||
|
|
||||||
|
/// @brief Number of locations to sample from the dataset
|
||||||
|
/// More locations -> more users/more breweries
|
||||||
|
uint32_t location_count;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
|
|||||||
opt("seed", prog_opts::value<int>()->default_value(sampling_defaults.seed),
|
opt("seed", prog_opts::value<int>()->default_value(sampling_defaults.seed),
|
||||||
"Sampler seed: -1 for random, otherwise non-negative integer");
|
"Sampler seed: -1 for random, otherwise non-negative integer");
|
||||||
opt("n-gpu-layers", prog_opts::value<int>()->default_value(0),
|
opt("n-gpu-layers", prog_opts::value<int>()->default_value(0),
|
||||||
"Number of layers to offload to GPU");
|
"Number of layers to offload to GPU");
|
||||||
};
|
};
|
||||||
|
|
||||||
// --mocked and --model are mutually exclusive; validation is enforced below
|
// --mocked and --model are mutually exclusive; validation is enforced below
|
||||||
@@ -52,7 +52,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
|
|||||||
opt("prompt-dir", prog_opts::value<std::string>()->default_value(""),
|
opt("prompt-dir", prog_opts::value<std::string>()->default_value(""),
|
||||||
"Directory containing named prompt files (e.g. BREWERY_GENERATION.md)."
|
"Directory containing named prompt files (e.g. BREWERY_GENERATION.md)."
|
||||||
" Required when not using --mocked.");
|
" Required when not using --mocked.");
|
||||||
|
opt("location-count", prog_opts::value<uint32_t>()->default_value(10));
|
||||||
};
|
};
|
||||||
|
|
||||||
add_sampling_options();
|
add_sampling_options();
|
||||||
@@ -85,6 +85,8 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
|
|||||||
options.pipeline.output_path = var_map["output"].as<std::string>();
|
options.pipeline.output_path = var_map["output"].as<std::string>();
|
||||||
options.pipeline.log_path = var_map["log-path"].as<std::string>();
|
options.pipeline.log_path = var_map["log-path"].as<std::string>();
|
||||||
options.pipeline.prompt_dir = var_map["prompt-dir"].as<std::string>();
|
options.pipeline.prompt_dir = var_map["prompt-dir"].as<std::string>();
|
||||||
|
options.pipeline.location_count =
|
||||||
|
var_map["location-count"].as<uint32_t>();
|
||||||
|
|
||||||
const bool use_mocked = var_map["mocked"].as<bool>();
|
const bool use_mocked = var_map["mocked"].as<bool>();
|
||||||
const std::string model_path = var_map["model"].as<std::string>();
|
const std::string model_path = var_map["model"].as<std::string>();
|
||||||
|
|||||||
@@ -10,7 +10,9 @@
|
|||||||
BiergartenDataGenerator::BiergartenDataGenerator(
|
BiergartenDataGenerator::BiergartenDataGenerator(
|
||||||
std::unique_ptr<IEnrichmentService> context_service,
|
std::unique_ptr<IEnrichmentService> context_service,
|
||||||
std::unique_ptr<DataGenerator> generator,
|
std::unique_ptr<DataGenerator> generator,
|
||||||
std::unique_ptr<IExportService> exporter)
|
std::unique_ptr<IExportService> exporter,
|
||||||
|
const ApplicationOptions &app_options)
|
||||||
: context_service_(std::move(context_service)),
|
: context_service_(std::move(context_service)),
|
||||||
generator_(std::move(generator)),
|
generator_(std::move(generator)),
|
||||||
exporter_(std::move(exporter)) {}
|
exporter_(std::move(exporter)),
|
||||||
|
application_options_(app_options) {}
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
#include "biergarten_data_generator.h"
|
#include "biergarten_data_generator.h"
|
||||||
#include "json_handling/json_loader.h"
|
#include "json_handling/json_loader.h"
|
||||||
|
|
||||||
static constexpr size_t kBreweryAmount = 40;
|
|
||||||
|
|
||||||
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
||||||
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
||||||
|
|
||||||
@@ -23,7 +21,9 @@ std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
|||||||
auto all_locations = JsonLoader::LoadLocations(locations_path);
|
auto all_locations = JsonLoader::LoadLocations(locations_path);
|
||||||
spdlog::info(" Locations available: {}", all_locations.size());
|
spdlog::info(" Locations available: {}", all_locations.size());
|
||||||
|
|
||||||
const size_t sample_count = std::min(kBreweryAmount, all_locations.size());
|
const size_t sample_count = std::min(
|
||||||
|
static_cast<size_t>(application_options_.pipeline.location_count),
|
||||||
|
all_locations.size());
|
||||||
|
|
||||||
const auto sample_count_signed =
|
const auto sample_count_signed =
|
||||||
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
|
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
|
||||||
|
|||||||
Reference in New Issue
Block a user