Refactor ApplicationOptions to separate config concerns

This commit is contained in:
Aaron Po
2026-05-01 00:40:21 -04:00
parent 641a479b6a
commit 01849062d5
9 changed files with 142 additions and 93 deletions

View File

@@ -44,41 +44,44 @@ LlamaGenerator::LlamaGenerator(
"LlamaGenerator: prompt formatter dependency must not be null");
}
if (options.temperature < 0.0F) {
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
if (sampling.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
if (sampling.top_p <= 0.0F || sampling.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.top_k == 0U) {
if (sampling.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
}
if (options.seed < -1) {
if (sampling.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) {
if (sampling.n_ctx == 0 || sampling.n_ctx > kMaxContextSize) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k;
sampling_temperature_ = sampling.temperature;
sampling_top_p_ = sampling.top_p;
sampling_top_k_ = sampling.top_k;
if (options.seed == -1) {
if (sampling.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
rng_.seed(static_cast<uint32_t>(sampling.seed));
}
n_ctx_ = options.n_ctx;
n_ctx_ = sampling.n_ctx;
this->Load(model_path);
}

View File

@@ -9,6 +9,7 @@
#include <boost/di.hpp>
#include <boost/program_options.hpp>
#include <chrono>
#include <cstdint>
#include <exception>
#include <memory>
#include <optional>
@@ -45,28 +46,31 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
opt("help,h", "Produce help message");
// Generator Options
opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data");
opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)");
// Sampling Options
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)");
"Context window size in tokens");
opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case
// Pipeline Options
opt("output,o", prog_opts::value<std::string>()->default_value("output"),
"Directory for generated artifacts");
opt("log-path",
prog_opts::value<std::string>()->default_value("pipeline.log"),
"Path for application logs");
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream;
@@ -76,20 +80,24 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
}
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
prog_opts::variables_map vm;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), vm);
prog_opts::notify(vm);
if (variables_map.contains("help")) {
if (vm.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt;
}
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
ApplicationOptions options;
options.pipeline.output_path = vm["output"].as<std::string>();
options.pipeline.log_path = vm["log-path"].as<std::string>();
const bool use_mocked = vm["mocked"].as<bool>();
const std::string model_path = vm["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error(
@@ -103,26 +111,29 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
return std::nullopt;
}
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted();
options.generator.use_mocked = use_mocked;
options.generator.model_path = model_path;
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
const bool user_provided_sampling =
!vm["temperature"].defaulted() || !vm["top-p"].defaulted() ||
!vm["top-k"].defaulted() || !vm["n-ctx"].defaulted() ||
!vm["seed"].defaulted();
if (use_mocked) {
if (user_provided_sampling) {
spdlog::warn("Sampling parameters are ignored when using --mocked");
}
} else if (user_provided_sampling) {
SamplingOptions sampling;
sampling.temperature = vm["temperature"].as<float>();
sampling.top_p = vm["top-p"].as<float>();
sampling.top_k = vm["top-k"].as<uint32_t>();
sampling.n_ctx = vm["n-ctx"].as<uint32_t>();
sampling.seed = vm["seed"].as<int>();
options.generator.sampling = sampling;
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
@@ -157,6 +168,9 @@ int main(const int argc, char** argv) {
}
const auto options = *parsed_options;
const std::string model_path = options.generator.model_path.string();
const auto sampling =
options.generator.sampling.value_or(SamplingOptions{});
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
@@ -164,10 +178,11 @@ int main(const int argc, char** argv) {
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<std::string>().to(options.model_path),
di::bind<std::string>().to(model_path),
di::bind<DataGenerator>().to(
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
[options, model_path,
sampling](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>();
@@ -176,8 +191,8 @@ int main(const int argc, char** argv) {
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed);
model_path, sampling.temperature, sampling.top_p,
sampling.top_k, sampling.n_ctx, sampling.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>();
}));

View File

@@ -11,11 +11,10 @@
std::filesystem::path SqliteExportService::BuildDatabasePath() const {
std::filesystem::path base_filename("biergarten_seed_" + run_timestamp_utc_ +
".sqlite");
std::filesystem::path candidate =
std::filesystem::current_path() / base_filename;
std::filesystem::path candidate = output_path_ / base_filename;
for (int suffix = 1; std::filesystem::exists(candidate); ++suffix) {
candidate = std::filesystem::current_path() /
candidate = output_path_ /
std::filesystem::path("biergarten_seed_" + run_timestamp_utc_ +
"-" + std::to_string(suffix) + ".sqlite");
}

View File

@@ -7,11 +7,12 @@
#include <memory>
// Default constructor: initializes date_time_provider_ with a newly created
// SystemDateTimeProvider. No other member is explicitly initialized here.
SqliteExportService::SqliteExportService()
: date_time_provider_(std::make_unique<SystemDateTimeProvider>()) {}
// Constructs the service from the application options.
//
// - date_time_provider_ is initialized with a newly created
//   SystemDateTimeProvider.
// - output_path_ is initialized from options.pipeline.output_path; it is the
//   base path that BuildDatabasePath() joins the database filename onto.
//
// NOTE(review): nothing here checks that the output path exists or is a
// directory — confirm whether a caller creates it before the export runs.
SqliteExportService::SqliteExportService(const ApplicationOptions& options)
: date_time_provider_(std::make_unique<SystemDateTimeProvider>()),
output_path_(options.pipeline.output_path) {}
// Destructor: when a database handle is still held, hands cleanup to
// RollbackAndCloseNoThrow(); when no handle is held there is nothing to do.
SqliteExportService::~SqliteExportService() {
  // Guard clause: no open handle means no cleanup is required.
  if (db_handle_ == nullptr) {
    return;
  }
  RollbackAndCloseNoThrow();
}
}