From 01849062d58a90809e019ce76bf8109c8258ae80 Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Fri, 1 May 2026 00:40:21 -0400 Subject: [PATCH] Refactor ApplicationOptions to separate config concerns --- docs/pipeline/diagrams/planned/class.puml | 31 +++---- tooling/pipeline/CMakeLists.txt | 18 ++-- .../includes/data_model/application_options.h | 56 ++++++++---- .../includes/services/export_service.h | 2 + .../includes/services/sqlite_export_service.h | 4 +- .../data_generation/llama/llama_generator.cc | 25 +++--- tooling/pipeline/src/main.cc | 87 +++++++++++-------- .../services/sqlite/build_database_path.cc | 5 +- .../services/sqlite/sqlite_export_service.cc | 7 +- 9 files changed, 142 insertions(+), 93 deletions(-) diff --git a/docs/pipeline/diagrams/planned/class.puml b/docs/pipeline/diagrams/planned/class.puml index 3b775cf..fd950c9 100644 --- a/docs/pipeline/diagrams/planned/class.puml +++ b/docs/pipeline/diagrams/planned/class.puml @@ -141,37 +141,38 @@ package "Domain: Models" { LocationContext *-- Completeness } - -package "Domain: Application Configuration"{ +@startuml +package "Domain: Application Configuration" { class SamplingOptions { - + temperature : float = 1.0F - + top_p : float = 0.95F - + top_k : uint32_t = 64 - + n_ctx : uint32_t = 8192 - + seed : int = -1 + + temperature: float = 1.0F + + top_p: float = 0.95F + + top_k: uint32_t = 64 + + n_ctx: uint32_t = 8192 + + seed: int = -1 } class GeneratorOptions { - + model_path : std::filesystem::path - + use_mocked : bool = false - + sampling : SamplingOptions + + model_path: std::filesystem::path + + use_mocked: bool = false + + sampling: std::optional } class PipelineOptions { - + output_path : std::filesystem::path - + log_path : std::filesystem::path + + output_path: std::filesystem::path + + log_path: std::filesystem::path } class ApplicationOptions { - + generator : GeneratorOptions - + pipeline : PipelineOptions + + generator: GeneratorOptions + + pipeline: PipelineOptions } ' --- Domain Model 
Relationships --- ApplicationOptions *-- GeneratorOptions ApplicationOptions *-- PipelineOptions - GeneratorOptions *-- SamplingOptions + GeneratorOptions o-- SamplingOptions } +@enduml package "Domain: Policy" { diff --git a/tooling/pipeline/CMakeLists.txt b/tooling/pipeline/CMakeLists.txt index 651b985..24628fd 100644 --- a/tooling/pipeline/CMakeLists.txt +++ b/tooling/pipeline/CMakeLists.txt @@ -85,14 +85,14 @@ endif() FetchContent_Declare( llama-cpp GIT_REPOSITORY https://github.com/ggml-org/llama.cpp.git - GIT_TAG b8742 + GIT_TAG b8742 ) FetchContent_MakeAvailable(llama-cpp) FetchContent_Declare( boost-di GIT_REPOSITORY https://github.com/boost-ext/di.git - GIT_TAG v1.3.0 + GIT_TAG v1.3.0 ) FetchContent_MakeAvailable(boost-di) if(TARGET Boost.DI AND NOT TARGET boost::di) @@ -102,7 +102,7 @@ endif() FetchContent_Declare( spdlog GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG v1.15.3 + GIT_TAG v1.15.3 ) FetchContent_MakeAvailable(spdlog) @@ -121,8 +121,8 @@ set(SOURCES src/services/wikipedia/fetch_extract.cc src/services/sqlite/sqlite_export_service.cc src/services/sqlite/build_database_path.cc - src/services/sqlite/process_record.cc - src/services/sqlite/initialize.cc + src/services/sqlite/process_record.cc + src/services/sqlite/initialize.cc src/services/sqlite/finalize.cc src/web_client/curl_global_state.cc src/web_client/curl_web_client_get.cc @@ -139,8 +139,8 @@ set(SOURCES src/data_generation/mock/generate_brewery.cc src/data_generation/mock/generate_user.cc src/json_handling/json_loader.cc - src/services/sqlite/helpers/sqlite_connection_helpers.cpp - src/services/sqlite/helpers/sqlite_statement_helpers.cpp + src/services/sqlite/helpers/sqlite_connection_helpers.cpp + src/services/sqlite/helpers/sqlite_statement_helpers.cpp ) # ============================================================================= @@ -173,6 +173,6 @@ configure_file( add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory -
${CMAKE_SOURCE_DIR}/prompts - ${CMAKE_BINARY_DIR}/prompts + ${CMAKE_SOURCE_DIR}/prompts + ${CMAKE_BINARY_DIR}/prompts ) diff --git a/tooling/pipeline/includes/data_model/application_options.h b/tooling/pipeline/includes/data_model/application_options.h index 1d36bd7..08b7d8c 100644 --- a/tooling/pipeline/includes/data_model/application_options.h +++ b/tooling/pipeline/includes/data_model/application_options.h @@ -7,36 +7,62 @@ */ #include +#include +#include #include /** - * @brief Program options for the Biergarten pipeline application. + * @brief LLM sampling parameters. */ -struct ApplicationOptions { - /// @brief Path to the LLM model file (gguf format); mutually exclusive with - /// use_mocked. - std::string model_path; - - /// @brief Use mocked generator instead of LLM; mutually exclusive with - /// model_path. - bool use_mocked = false; - +struct SamplingOptions { /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). float temperature = 1.0F; - /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more - /// random). + /// @brief LLM nucleus sampling top-p parameter. float top_p = 0.95F; /// @brief LLM top-k sampling parameter. uint32_t top_k = 64; - /// @brief Context window size (tokens) for LLM inference. Higher values - /// support longer prompts but use more memory. + /// @brief Context window size (tokens). uint32_t n_ctx = 8192; - /// @brief Random seed for sampling (-1 for random, otherwise non-negative). + /// @brief Random seed (-1 for random, otherwise non-negative). int seed = -1; }; +/** + * @brief Configuration for the LLM generator component. + */ +struct GeneratorOptions { + /// @brief Path to the LLM model file (gguf format). + std::filesystem::path model_path; + + /// @brief Use mocked generator instead of actual LLM inference. + bool use_mocked = false; + + /// @brief Specific sampling parameters for this generator. + /// If nullopt, the application should use global defaults. 
+ std::optional sampling; +}; + +/** + * @brief Configuration for the pipeline execution and output. + */ +struct PipelineOptions { + /// @brief Directory for generated artifacts. + std::filesystem::path output_path; + + /// @brief Path for application logs. + std::filesystem::path log_path; +}; + +/** + * @brief Root configuration object for the Biergarten pipeline. + */ +struct ApplicationOptions { + GeneratorOptions generator; + PipelineOptions pipeline; +}; + #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_ diff --git a/tooling/pipeline/includes/services/export_service.h b/tooling/pipeline/includes/services/export_service.h index 55d0b06..3b5d6c6 100644 --- a/tooling/pipeline/includes/services/export_service.h +++ b/tooling/pipeline/includes/services/export_service.h @@ -6,6 +6,8 @@ * @brief Abstraction for persisting generated brewery data. */ +#include + #include "data_model/generated_brewery.h" /** diff --git a/tooling/pipeline/includes/services/sqlite_export_service.h b/tooling/pipeline/includes/services/sqlite_export_service.h index 0fa998f..fdae8ff 100644 --- a/tooling/pipeline/includes/services/sqlite_export_service.h +++ b/tooling/pipeline/includes/services/sqlite_export_service.h @@ -11,6 +11,7 @@ #include #include +#include "data_model/application_options.h" #include "services/date_time_provider.h" #include "services/export_service.h" #include "services/sqlite_export_service_helpers.h" @@ -20,7 +21,7 @@ */ class SqliteExportService final : public IExportService { public: - SqliteExportService(); + explicit SqliteExportService(const ApplicationOptions& options); ~SqliteExportService() override; SqliteExportService(const SqliteExportService&) = delete; @@ -47,6 +48,7 @@ class SqliteExportService final : public IExportService { [[nodiscard]] static std::string BuildLocationKey(const Location& location); std::unique_ptr date_time_provider_; + std::filesystem::path output_path_; std::string run_timestamp_utc_; std::filesystem::path 
database_path_; SqliteDatabaseHandle db_handle_; diff --git a/tooling/pipeline/src/data_generation/llama/llama_generator.cc b/tooling/pipeline/src/data_generation/llama/llama_generator.cc index a854f48..5f28b1b 100644 --- a/tooling/pipeline/src/data_generation/llama/llama_generator.cc +++ b/tooling/pipeline/src/data_generation/llama/llama_generator.cc @@ -44,41 +44,44 @@ LlamaGenerator::LlamaGenerator( "LlamaGenerator: prompt formatter dependency must not be null"); } - if (options.temperature < 0.0F) { + const auto sampling = options.generator.sampling.value_or(SamplingOptions{}); + + if (sampling.temperature < 0.0F) { throw std::runtime_error( "LlamaGenerator: sampling temperature must be >= 0"); } - if (options.top_p <= 0.0F || options.top_p > 1.0F) { + if (sampling.top_p <= 0.0F || sampling.top_p > 1.0F) { throw std::runtime_error( "LlamaGenerator: sampling top-p must be in (0, 1]"); } - if (options.top_k == 0U) { + if (sampling.top_k == 0U) { throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0"); } - if (options.seed < -1) { + if (sampling.seed < -1) { throw std::runtime_error( "LlamaGenerator: seed must be >= 0, or -1 for random"); } - if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) { + if (sampling.n_ctx == 0 || sampling.n_ctx > kMaxContextSize) { throw std::runtime_error( "LlamaGenerator: context size must be in range [1, 32768]"); } - sampling_temperature_ = options.temperature; - sampling_top_p_ = options.top_p; - sampling_top_k_ = options.top_k; + sampling_temperature_ = sampling.temperature; + sampling_top_p_ = sampling.top_p; + sampling_top_k_ = sampling.top_k; - if (options.seed == -1) { + if (sampling.seed == -1) { std::random_device random_device; rng_.seed(random_device()); } else { - rng_.seed(static_cast(options.seed)); + rng_.seed(static_cast(sampling.seed)); } - n_ctx_ = options.n_ctx; + + n_ctx_ = sampling.n_ctx; this->Load(model_path); } diff --git a/tooling/pipeline/src/main.cc b/tooling/pipeline/src/main.cc index 
2ce3779..6206f4f 100644 --- a/tooling/pipeline/src/main.cc +++ b/tooling/pipeline/src/main.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -45,28 +46,31 @@ std::optional ParseArguments(const int argc, char** argv) { opt("help,h", "Produce help message"); + // Generator Options opt("mocked", prog_opts::bool_switch(), "Use mocked generator for brewery/user data"); - opt("model,m", prog_opts::value()->default_value(""), "Path to LLM model (gguf)"); + // Sampling Options opt("temperature", prog_opts::value()->default_value(1.0F), "Sampling temperature (higher = more random)"); - opt("top-p", prog_opts::value()->default_value(0.95F), "Nucleus sampling top-p in (0,1] (higher = more random)"); - opt("top-k", prog_opts::value()->default_value(64), "Top-k sampling parameter (higher = more candidate tokens)"); - opt("n-ctx", prog_opts::value()->default_value(8192), - "Context window size in tokens (1-32768)"); - + "Context window size in tokens"); opt("seed", prog_opts::value()->default_value(-1), "Sampler seed: -1 for random, otherwise non-negative integer"); - // Handle the "no arguments" or "help" case + // Pipeline Options + opt("output,o", prog_opts::value()->default_value("output"), + "Directory for generated artifacts"); + opt("log-path", + prog_opts::value()->default_value("pipeline.log"), + "Path for application logs"); + if (argc == 1) { spdlog::info("Biergarten Pipeline"); std::stringstream usage_stream; @@ -76,20 +80,24 @@ std::optional ParseArguments(const int argc, char** argv) { } try { - prog_opts::variables_map variables_map; - prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), - variables_map); - prog_opts::notify(variables_map); + prog_opts::variables_map vm; + prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), vm); + prog_opts::notify(vm); - if (variables_map.contains("help")) { + if (vm.contains("help")) { std::stringstream help_stream; help_stream << "\n" << desc; 
spdlog::info(help_stream.str()); return std::nullopt; } - const auto use_mocked = variables_map["mocked"].as(); - const auto model_path = variables_map["model"].as(); + ApplicationOptions options; + + options.pipeline.output_path = vm["output"].as(); + options.pipeline.log_path = vm["log-path"].as(); + + const bool use_mocked = vm["mocked"].as(); + const std::string model_path = vm["model"].as(); if (use_mocked && !model_path.empty()) { spdlog::error( @@ -103,26 +111,29 @@ std::optional ParseArguments(const int argc, char** argv) { return std::nullopt; } - const bool has_llm_params = !variables_map["temperature"].defaulted() || - !variables_map["top-p"].defaulted() || - !variables_map["top-k"].defaulted() || - !variables_map["seed"].defaulted(); + options.generator.use_mocked = use_mocked; + options.generator.model_path = model_path; - if (use_mocked && has_llm_params) { - spdlog::warn( - "Sampling parameters (--temperature, --top-p, --top-k, --seed) are" - " ignored when using --mocked"); + const bool user_provided_sampling = + !vm["temperature"].defaulted() || !vm["top-p"].defaulted() || + !vm["top-k"].defaulted() || !vm["n-ctx"].defaulted() || + !vm["seed"].defaulted(); + + if (use_mocked) { + if (user_provided_sampling) { + spdlog::warn("Sampling parameters are ignored when using --mocked"); + } + } else if (user_provided_sampling) { + SamplingOptions sampling; + sampling.temperature = vm["temperature"].as(); + sampling.top_p = vm["top-p"].as(); + sampling.top_k = vm["top-k"].as(); + sampling.n_ctx = vm["n-ctx"].as(); + sampling.seed = vm["seed"].as(); + + options.generator.sampling = sampling; } - ApplicationOptions options; - options.use_mocked = use_mocked; - options.model_path = model_path; - options.temperature = variables_map["temperature"].as(); - options.top_p = variables_map["top-p"].as(); - options.top_k = variables_map["top-k"].as(); - options.n_ctx = variables_map["n-ctx"].as(); - options.seed = variables_map["seed"].as(); - return options; } catch 
(const std::exception& exception) { spdlog::error("Failed to parse command-line arguments: {}", @@ -157,6 +168,9 @@ int main(const int argc, char** argv) { } const auto options = *parsed_options; + const std::string model_path = options.generator.model_path.string(); + const auto sampling = + options.generator.sampling.value_or(SamplingOptions{}); const auto injector = di::make_injector( di::bind().to(), @@ -164,10 +178,11 @@ int main(const int argc, char** argv) { di::bind().to(), di::bind().to(), di::bind().to(), - di::bind().to(options.model_path), + di::bind().to(model_path), di::bind().to( - [options](const auto& inj) -> std::unique_ptr { - if (options.use_mocked) { + [options, model_path, + sampling](const auto& inj) -> std::unique_ptr { + if (options.generator.use_mocked) { spdlog::info( "[Generator] Using MockGenerator (no model path provided)"); return std::make_unique(); @@ -176,8 +191,8 @@ int main(const int argc, char** argv) { spdlog::info( "[Generator] Using LlamaGenerator: {} (temperature={}, " "top-p={}, top-k={}, n_ctx={}, seed={})", - options.model_path, options.temperature, options.top_p, - options.top_k, options.n_ctx, options.seed); + model_path, sampling.temperature, sampling.top_p, + sampling.top_k, sampling.n_ctx, sampling.seed); return inj.template create>(); })); diff --git a/tooling/pipeline/src/services/sqlite/build_database_path.cc b/tooling/pipeline/src/services/sqlite/build_database_path.cc index 3a96cdf..8786fe4 100644 --- a/tooling/pipeline/src/services/sqlite/build_database_path.cc +++ b/tooling/pipeline/src/services/sqlite/build_database_path.cc @@ -11,11 +11,10 @@ std::filesystem::path SqliteExportService::BuildDatabasePath() const { std::filesystem::path base_filename("biergarten_seed_" + run_timestamp_utc_ + ".sqlite"); - std::filesystem::path candidate = - std::filesystem::current_path() / base_filename; + std::filesystem::path candidate = output_path_ / base_filename; for (int suffix = 1; std::filesystem::exists(candidate); 
++suffix) { - candidate = std::filesystem::current_path() / + candidate = output_path_ / std::filesystem::path("biergarten_seed_" + run_timestamp_utc_ + "-" + std::to_string(suffix) + ".sqlite"); } diff --git a/tooling/pipeline/src/services/sqlite/sqlite_export_service.cc b/tooling/pipeline/src/services/sqlite/sqlite_export_service.cc index 377c917..4bf66a3 100644 --- a/tooling/pipeline/src/services/sqlite/sqlite_export_service.cc +++ b/tooling/pipeline/src/services/sqlite/sqlite_export_service.cc @@ -7,11 +7,12 @@ #include -SqliteExportService::SqliteExportService() - : date_time_provider_(std::make_unique()) {} +SqliteExportService::SqliteExportService(const ApplicationOptions& options) + : date_time_provider_(std::make_unique()), + output_path_(options.pipeline.output_path) {} SqliteExportService::~SqliteExportService() { if (db_handle_ != nullptr) { RollbackAndCloseNoThrow(); } -} \ No newline at end of file +}