mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Refactor ApplicationOptions to separate config concerns
This commit is contained in:
@@ -141,37 +141,38 @@ package "Domain: Models" {
|
|||||||
|
|
||||||
LocationContext *-- Completeness
|
LocationContext *-- Completeness
|
||||||
}
|
}
|
||||||
|
@startuml
|
||||||
package "Domain: Application Configuration"{
|
package "Domain: Application Configuration" {
|
||||||
class SamplingOptions {
|
class SamplingOptions {
|
||||||
+ temperature : float = 1.0F
|
+ temperature: float = 1.0F
|
||||||
+ top_p : float = 0.95F
|
+ top_p: float = 0.95F
|
||||||
+ top_k : uint32_t = 64
|
+ top_k: uint32_t = 64
|
||||||
+ n_ctx : uint32_t = 8192
|
+ n_ctx: uint32_t = 8192
|
||||||
+ seed : int = -1
|
+ seed: int = -1
|
||||||
}
|
}
|
||||||
|
|
||||||
class GeneratorOptions {
|
class GeneratorOptions {
|
||||||
+ model_path : std::filesystem::path
|
+ model_path: std::filesystem::path
|
||||||
+ use_mocked : bool = false
|
+ use_mocked: bool = false
|
||||||
+ sampling : SamplingOptions
|
+ sampling: std::optional<SamplingOptions>
|
||||||
}
|
}
|
||||||
|
|
||||||
class PipelineOptions {
|
class PipelineOptions {
|
||||||
+ output_path : std::filesystem::path
|
+ output_path: std::filesystem::path
|
||||||
+ log_path : std::filesystem::path
|
+ log_path: std::filesystem::path
|
||||||
}
|
}
|
||||||
|
|
||||||
class ApplicationOptions {
|
class ApplicationOptions {
|
||||||
+ generator : GeneratorOptions
|
+ generator: GeneratorOptions
|
||||||
+ pipeline : PipelineOptions
|
+ pipeline: PipelineOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
' --- Domain Model Relationships ---
|
' --- Domain Model Relationships ---
|
||||||
ApplicationOptions *-- GeneratorOptions
|
ApplicationOptions *-- GeneratorOptions
|
||||||
ApplicationOptions *-- PipelineOptions
|
ApplicationOptions *-- PipelineOptions
|
||||||
GeneratorOptions *-- SamplingOptions
|
GeneratorOptions o-- SamplingOptions
|
||||||
}
|
}
|
||||||
|
@endum
|
||||||
|
|
||||||
package "Domain: Policy" {
|
package "Domain: Policy" {
|
||||||
|
|
||||||
|
|||||||
@@ -85,14 +85,14 @@ endif()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
llama-cpp
|
llama-cpp
|
||||||
GIT_REPOSITORY https://github.com/ggml-org/llama.cpp.git
|
GIT_REPOSITORY https://github.com/ggml-org/llama.cpp.git
|
||||||
GIT_TAG b8742
|
GIT_TAG b8742
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(llama-cpp)
|
FetchContent_MakeAvailable(llama-cpp)
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
boost-di
|
boost-di
|
||||||
GIT_REPOSITORY https://github.com/boost-ext/di.git
|
GIT_REPOSITORY https://github.com/boost-ext/di.git
|
||||||
GIT_TAG v1.3.0
|
GIT_TAG v1.3.0
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(boost-di)
|
FetchContent_MakeAvailable(boost-di)
|
||||||
if(TARGET Boost.DI AND NOT TARGET boost::di)
|
if(TARGET Boost.DI AND NOT TARGET boost::di)
|
||||||
@@ -102,7 +102,7 @@ endif()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
spdlog
|
spdlog
|
||||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||||
GIT_TAG v1.15.3
|
GIT_TAG v1.15.3
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(spdlog)
|
FetchContent_MakeAvailable(spdlog)
|
||||||
|
|
||||||
@@ -121,8 +121,8 @@ set(SOURCES
|
|||||||
src/services/wikipedia/fetch_extract.cc
|
src/services/wikipedia/fetch_extract.cc
|
||||||
src/services/sqlite/sqlite_export_service.cc
|
src/services/sqlite/sqlite_export_service.cc
|
||||||
src/services/sqlite/build_database_path.cc
|
src/services/sqlite/build_database_path.cc
|
||||||
src/services/sqlite/process_record.cc
|
src/services/sqlite/process_record.cc
|
||||||
src/services/sqlite/initialize.cc
|
src/services/sqlite/initialize.cc
|
||||||
src/services/sqlite/finalize.cc
|
src/services/sqlite/finalize.cc
|
||||||
src/web_client/curl_global_state.cc
|
src/web_client/curl_global_state.cc
|
||||||
src/web_client/curl_web_client_get.cc
|
src/web_client/curl_web_client_get.cc
|
||||||
@@ -139,8 +139,8 @@ set(SOURCES
|
|||||||
src/data_generation/mock/generate_brewery.cc
|
src/data_generation/mock/generate_brewery.cc
|
||||||
src/data_generation/mock/generate_user.cc
|
src/data_generation/mock/generate_user.cc
|
||||||
src/json_handling/json_loader.cc
|
src/json_handling/json_loader.cc
|
||||||
src/services/sqlite/helpers/sqlite_connection_helpers.cpp
|
src/services/sqlite/helpers/sqlite_connection_helpers.cpp
|
||||||
src/services/sqlite/helpers/sqlite_statement_helpers.cpp
|
src/services/sqlite/helpers/sqlite_statement_helpers.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -173,6 +173,6 @@ configure_file(
|
|||||||
|
|
||||||
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
|
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
|
||||||
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||||
${CMAKE_SOURCE_DIR}/prompts
|
${CMAKE_SOURCE_DIR}/prompts
|
||||||
${CMAKE_BINARY_DIR}/prompts
|
${CMAKE_BINARY_DIR}/prompts
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,36 +7,62 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Program options for the Biergarten pipeline application.
|
* @brief LLM sampling parameters.
|
||||||
*/
|
*/
|
||||||
struct ApplicationOptions {
|
struct SamplingOptions {
|
||||||
/// @brief Path to the LLM model file (gguf format); mutually exclusive with
|
|
||||||
/// use_mocked.
|
|
||||||
std::string model_path;
|
|
||||||
|
|
||||||
/// @brief Use mocked generator instead of LLM; mutually exclusive with
|
|
||||||
/// model_path.
|
|
||||||
bool use_mocked = false;
|
|
||||||
|
|
||||||
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
|
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
|
||||||
float temperature = 1.0F;
|
float temperature = 1.0F;
|
||||||
|
|
||||||
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
|
/// @brief LLM nucleus sampling top-p parameter.
|
||||||
/// random).
|
|
||||||
float top_p = 0.95F;
|
float top_p = 0.95F;
|
||||||
|
|
||||||
/// @brief LLM top-k sampling parameter.
|
/// @brief LLM top-k sampling parameter.
|
||||||
uint32_t top_k = 64;
|
uint32_t top_k = 64;
|
||||||
|
|
||||||
/// @brief Context window size (tokens) for LLM inference. Higher values
|
/// @brief Context window size (tokens).
|
||||||
/// support longer prompts but use more memory.
|
|
||||||
uint32_t n_ctx = 8192;
|
uint32_t n_ctx = 8192;
|
||||||
|
|
||||||
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
|
/// @brief Random seed (-1 for random, otherwise non-negative).
|
||||||
int seed = -1;
|
int seed = -1;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Configuration for the LLM generator component.
|
||||||
|
*/
|
||||||
|
struct GeneratorOptions {
|
||||||
|
/// @brief Path to the LLM model file (gguf format).
|
||||||
|
std::filesystem::path model_path;
|
||||||
|
|
||||||
|
/// @brief Use mocked generator instead of actual LLM inference.
|
||||||
|
bool use_mocked = false;
|
||||||
|
|
||||||
|
/// @brief Specific sampling parameters for this generator.
|
||||||
|
/// If nullopt, the application should use global defaults.
|
||||||
|
std::optional<SamplingOptions> sampling;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Configuration for the pipeline execution and output.
|
||||||
|
*/
|
||||||
|
struct PipelineOptions {
|
||||||
|
/// @brief Directory for generated artifacts.
|
||||||
|
std::filesystem::path output_path;
|
||||||
|
|
||||||
|
/// @brief Path for application logs.
|
||||||
|
std::filesystem::path log_path;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Root configuration object for the Biergarten pipeline.
|
||||||
|
*/
|
||||||
|
struct ApplicationOptions {
|
||||||
|
GeneratorOptions generator;
|
||||||
|
PipelineOptions pipeline;
|
||||||
|
};
|
||||||
|
|
||||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_
|
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
* @brief Abstraction for persisting generated brewery data.
|
* @brief Abstraction for persisting generated brewery data.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
#include "data_model/generated_brewery.h"
|
#include "data_model/generated_brewery.h"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "data_model/application_options.h"
|
||||||
#include "services/date_time_provider.h"
|
#include "services/date_time_provider.h"
|
||||||
#include "services/export_service.h"
|
#include "services/export_service.h"
|
||||||
#include "services/sqlite_export_service_helpers.h"
|
#include "services/sqlite_export_service_helpers.h"
|
||||||
@@ -20,7 +21,7 @@
|
|||||||
*/
|
*/
|
||||||
class SqliteExportService final : public IExportService {
|
class SqliteExportService final : public IExportService {
|
||||||
public:
|
public:
|
||||||
SqliteExportService();
|
explicit SqliteExportService(const ApplicationOptions& options);
|
||||||
~SqliteExportService() override;
|
~SqliteExportService() override;
|
||||||
|
|
||||||
SqliteExportService(const SqliteExportService&) = delete;
|
SqliteExportService(const SqliteExportService&) = delete;
|
||||||
@@ -47,6 +48,7 @@ class SqliteExportService final : public IExportService {
|
|||||||
[[nodiscard]] static std::string BuildLocationKey(const Location& location);
|
[[nodiscard]] static std::string BuildLocationKey(const Location& location);
|
||||||
|
|
||||||
std::unique_ptr<IDateTimeProvider> date_time_provider_;
|
std::unique_ptr<IDateTimeProvider> date_time_provider_;
|
||||||
|
std::filesystem::path output_path_;
|
||||||
std::string run_timestamp_utc_;
|
std::string run_timestamp_utc_;
|
||||||
std::filesystem::path database_path_;
|
std::filesystem::path database_path_;
|
||||||
SqliteDatabaseHandle db_handle_;
|
SqliteDatabaseHandle db_handle_;
|
||||||
|
|||||||
@@ -44,41 +44,44 @@ LlamaGenerator::LlamaGenerator(
|
|||||||
"LlamaGenerator: prompt formatter dependency must not be null");
|
"LlamaGenerator: prompt formatter dependency must not be null");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.temperature < 0.0F) {
|
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
|
||||||
|
|
||||||
|
if (sampling.temperature < 0.0F) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: sampling temperature must be >= 0");
|
"LlamaGenerator: sampling temperature must be >= 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
|
if (sampling.top_p <= 0.0F || sampling.top_p > 1.0F) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: sampling top-p must be in (0, 1]");
|
"LlamaGenerator: sampling top-p must be in (0, 1]");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.top_k == 0U) {
|
if (sampling.top_k == 0U) {
|
||||||
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
|
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.seed < -1) {
|
if (sampling.seed < -1) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: seed must be >= 0, or -1 for random");
|
"LlamaGenerator: seed must be >= 0, or -1 for random");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) {
|
if (sampling.n_ctx == 0 || sampling.n_ctx > kMaxContextSize) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: context size must be in range [1, 32768]");
|
"LlamaGenerator: context size must be in range [1, 32768]");
|
||||||
}
|
}
|
||||||
|
|
||||||
sampling_temperature_ = options.temperature;
|
sampling_temperature_ = sampling.temperature;
|
||||||
sampling_top_p_ = options.top_p;
|
sampling_top_p_ = sampling.top_p;
|
||||||
sampling_top_k_ = options.top_k;
|
sampling_top_k_ = sampling.top_k;
|
||||||
|
|
||||||
if (options.seed == -1) {
|
if (sampling.seed == -1) {
|
||||||
std::random_device random_device;
|
std::random_device random_device;
|
||||||
rng_.seed(random_device());
|
rng_.seed(random_device());
|
||||||
} else {
|
} else {
|
||||||
rng_.seed(static_cast<uint32_t>(options.seed));
|
rng_.seed(static_cast<uint32_t>(sampling.seed));
|
||||||
}
|
}
|
||||||
n_ctx_ = options.n_ctx;
|
|
||||||
|
n_ctx_ = sampling.n_ctx;
|
||||||
|
|
||||||
this->Load(model_path);
|
this->Load(model_path);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <boost/di.hpp>
|
#include <boost/di.hpp>
|
||||||
#include <boost/program_options.hpp>
|
#include <boost/program_options.hpp>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <cstdint>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@@ -45,28 +46,31 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
|
|||||||
|
|
||||||
opt("help,h", "Produce help message");
|
opt("help,h", "Produce help message");
|
||||||
|
|
||||||
|
// Generator Options
|
||||||
opt("mocked", prog_opts::bool_switch(),
|
opt("mocked", prog_opts::bool_switch(),
|
||||||
"Use mocked generator for brewery/user data");
|
"Use mocked generator for brewery/user data");
|
||||||
|
|
||||||
opt("model,m", prog_opts::value<std::string>()->default_value(""),
|
opt("model,m", prog_opts::value<std::string>()->default_value(""),
|
||||||
"Path to LLM model (gguf)");
|
"Path to LLM model (gguf)");
|
||||||
|
|
||||||
|
// Sampling Options
|
||||||
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
|
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
|
||||||
"Sampling temperature (higher = more random)");
|
"Sampling temperature (higher = more random)");
|
||||||
|
|
||||||
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
|
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
|
||||||
"Nucleus sampling top-p in (0,1] (higher = more random)");
|
"Nucleus sampling top-p in (0,1] (higher = more random)");
|
||||||
|
|
||||||
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
|
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
|
||||||
"Top-k sampling parameter (higher = more candidate tokens)");
|
"Top-k sampling parameter (higher = more candidate tokens)");
|
||||||
|
|
||||||
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
|
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
|
||||||
"Context window size in tokens (1-32768)");
|
"Context window size in tokens");
|
||||||
|
|
||||||
opt("seed", prog_opts::value<int>()->default_value(-1),
|
opt("seed", prog_opts::value<int>()->default_value(-1),
|
||||||
"Sampler seed: -1 for random, otherwise non-negative integer");
|
"Sampler seed: -1 for random, otherwise non-negative integer");
|
||||||
|
|
||||||
// Handle the "no arguments" or "help" case
|
// Pipeline Options
|
||||||
|
opt("output,o", prog_opts::value<std::string>()->default_value("output"),
|
||||||
|
"Directory for generated artifacts");
|
||||||
|
opt("log-path",
|
||||||
|
prog_opts::value<std::string>()->default_value("pipeline.log"),
|
||||||
|
"Path for application logs");
|
||||||
|
|
||||||
if (argc == 1) {
|
if (argc == 1) {
|
||||||
spdlog::info("Biergarten Pipeline");
|
spdlog::info("Biergarten Pipeline");
|
||||||
std::stringstream usage_stream;
|
std::stringstream usage_stream;
|
||||||
@@ -76,20 +80,24 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
prog_opts::variables_map variables_map;
|
prog_opts::variables_map vm;
|
||||||
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
|
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), vm);
|
||||||
variables_map);
|
prog_opts::notify(vm);
|
||||||
prog_opts::notify(variables_map);
|
|
||||||
|
|
||||||
if (variables_map.contains("help")) {
|
if (vm.contains("help")) {
|
||||||
std::stringstream help_stream;
|
std::stringstream help_stream;
|
||||||
help_stream << "\n" << desc;
|
help_stream << "\n" << desc;
|
||||||
spdlog::info(help_stream.str());
|
spdlog::info(help_stream.str());
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto use_mocked = variables_map["mocked"].as<bool>();
|
ApplicationOptions options;
|
||||||
const auto model_path = variables_map["model"].as<std::string>();
|
|
||||||
|
options.pipeline.output_path = vm["output"].as<std::string>();
|
||||||
|
options.pipeline.log_path = vm["log-path"].as<std::string>();
|
||||||
|
|
||||||
|
const bool use_mocked = vm["mocked"].as<bool>();
|
||||||
|
const std::string model_path = vm["model"].as<std::string>();
|
||||||
|
|
||||||
if (use_mocked && !model_path.empty()) {
|
if (use_mocked && !model_path.empty()) {
|
||||||
spdlog::error(
|
spdlog::error(
|
||||||
@@ -103,26 +111,29 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
|
options.generator.use_mocked = use_mocked;
|
||||||
!variables_map["top-p"].defaulted() ||
|
options.generator.model_path = model_path;
|
||||||
!variables_map["top-k"].defaulted() ||
|
|
||||||
!variables_map["seed"].defaulted();
|
|
||||||
|
|
||||||
if (use_mocked && has_llm_params) {
|
const bool user_provided_sampling =
|
||||||
spdlog::warn(
|
!vm["temperature"].defaulted() || !vm["top-p"].defaulted() ||
|
||||||
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
|
!vm["top-k"].defaulted() || !vm["n-ctx"].defaulted() ||
|
||||||
" ignored when using --mocked");
|
!vm["seed"].defaulted();
|
||||||
|
|
||||||
|
if (use_mocked) {
|
||||||
|
if (user_provided_sampling) {
|
||||||
|
spdlog::warn("Sampling parameters are ignored when using --mocked");
|
||||||
|
}
|
||||||
|
} else if (user_provided_sampling) {
|
||||||
|
SamplingOptions sampling;
|
||||||
|
sampling.temperature = vm["temperature"].as<float>();
|
||||||
|
sampling.top_p = vm["top-p"].as<float>();
|
||||||
|
sampling.top_k = vm["top-k"].as<uint32_t>();
|
||||||
|
sampling.n_ctx = vm["n-ctx"].as<uint32_t>();
|
||||||
|
sampling.seed = vm["seed"].as<int>();
|
||||||
|
|
||||||
|
options.generator.sampling = sampling;
|
||||||
}
|
}
|
||||||
|
|
||||||
ApplicationOptions options;
|
|
||||||
options.use_mocked = use_mocked;
|
|
||||||
options.model_path = model_path;
|
|
||||||
options.temperature = variables_map["temperature"].as<float>();
|
|
||||||
options.top_p = variables_map["top-p"].as<float>();
|
|
||||||
options.top_k = variables_map["top-k"].as<uint32_t>();
|
|
||||||
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
|
|
||||||
options.seed = variables_map["seed"].as<int>();
|
|
||||||
|
|
||||||
return options;
|
return options;
|
||||||
} catch (const std::exception& exception) {
|
} catch (const std::exception& exception) {
|
||||||
spdlog::error("Failed to parse command-line arguments: {}",
|
spdlog::error("Failed to parse command-line arguments: {}",
|
||||||
@@ -157,6 +168,9 @@ int main(const int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto options = *parsed_options;
|
const auto options = *parsed_options;
|
||||||
|
const std::string model_path = options.generator.model_path.string();
|
||||||
|
const auto sampling =
|
||||||
|
options.generator.sampling.value_or(SamplingOptions{});
|
||||||
|
|
||||||
const auto injector = di::make_injector(
|
const auto injector = di::make_injector(
|
||||||
di::bind<WebClient>().to<CURLWebClient>(),
|
di::bind<WebClient>().to<CURLWebClient>(),
|
||||||
@@ -164,10 +178,11 @@ int main(const int argc, char** argv) {
|
|||||||
di::bind<IEnrichmentService>().to<WikipediaService>(),
|
di::bind<IEnrichmentService>().to<WikipediaService>(),
|
||||||
di::bind<IExportService>().to<SqliteExportService>(),
|
di::bind<IExportService>().to<SqliteExportService>(),
|
||||||
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
|
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
|
||||||
di::bind<std::string>().to(options.model_path),
|
di::bind<std::string>().to(model_path),
|
||||||
di::bind<DataGenerator>().to(
|
di::bind<DataGenerator>().to(
|
||||||
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
|
[options, model_path,
|
||||||
if (options.use_mocked) {
|
sampling](const auto& inj) -> std::unique_ptr<DataGenerator> {
|
||||||
|
if (options.generator.use_mocked) {
|
||||||
spdlog::info(
|
spdlog::info(
|
||||||
"[Generator] Using MockGenerator (no model path provided)");
|
"[Generator] Using MockGenerator (no model path provided)");
|
||||||
return std::make_unique<MockGenerator>();
|
return std::make_unique<MockGenerator>();
|
||||||
@@ -176,8 +191,8 @@ int main(const int argc, char** argv) {
|
|||||||
spdlog::info(
|
spdlog::info(
|
||||||
"[Generator] Using LlamaGenerator: {} (temperature={}, "
|
"[Generator] Using LlamaGenerator: {} (temperature={}, "
|
||||||
"top-p={}, top-k={}, n_ctx={}, seed={})",
|
"top-p={}, top-k={}, n_ctx={}, seed={})",
|
||||||
options.model_path, options.temperature, options.top_p,
|
model_path, sampling.temperature, sampling.top_p,
|
||||||
options.top_k, options.n_ctx, options.seed);
|
sampling.top_k, sampling.n_ctx, sampling.seed);
|
||||||
return inj.template create<std::unique_ptr<LlamaGenerator>>();
|
return inj.template create<std::unique_ptr<LlamaGenerator>>();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|||||||
@@ -11,11 +11,10 @@
|
|||||||
std::filesystem::path SqliteExportService::BuildDatabasePath() const {
|
std::filesystem::path SqliteExportService::BuildDatabasePath() const {
|
||||||
std::filesystem::path base_filename("biergarten_seed_" + run_timestamp_utc_ +
|
std::filesystem::path base_filename("biergarten_seed_" + run_timestamp_utc_ +
|
||||||
".sqlite");
|
".sqlite");
|
||||||
std::filesystem::path candidate =
|
std::filesystem::path candidate = output_path_ / base_filename;
|
||||||
std::filesystem::current_path() / base_filename;
|
|
||||||
|
|
||||||
for (int suffix = 1; std::filesystem::exists(candidate); ++suffix) {
|
for (int suffix = 1; std::filesystem::exists(candidate); ++suffix) {
|
||||||
candidate = std::filesystem::current_path() /
|
candidate = output_path_ /
|
||||||
std::filesystem::path("biergarten_seed_" + run_timestamp_utc_ +
|
std::filesystem::path("biergarten_seed_" + run_timestamp_utc_ +
|
||||||
"-" + std::to_string(suffix) + ".sqlite");
|
"-" + std::to_string(suffix) + ".sqlite");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,11 +7,12 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
SqliteExportService::SqliteExportService()
|
SqliteExportService::SqliteExportService(const ApplicationOptions& options)
|
||||||
: date_time_provider_(std::make_unique<SystemDateTimeProvider>()) {}
|
: date_time_provider_(std::make_unique<SystemDateTimeProvider>()),
|
||||||
|
output_path_(options.pipeline.output_path) {}
|
||||||
|
|
||||||
SqliteExportService::~SqliteExportService() {
|
SqliteExportService::~SqliteExportService() {
|
||||||
if (db_handle_ != nullptr) {
|
if (db_handle_ != nullptr) {
|
||||||
RollbackAndCloseNoThrow();
|
RollbackAndCloseNoThrow();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user