2 Commits

Author SHA1 Message Date
Aaron Po
ec435df4ad add prompt dir app option 2026-05-01 11:55:44 -04:00
Aaron Po
01849062d5 Refactor ApplicationOptions to separate config concerns 2026-05-01 00:40:21 -04:00
15 changed files with 358 additions and 166 deletions

View File

@@ -141,37 +141,38 @@ package "Domain: Models" {
LocationContext *-- Completeness LocationContext *-- Completeness
} }
@startuml
package "Domain: Application Configuration"{ package "Domain: Application Configuration" {
class SamplingOptions { class SamplingOptions {
+ temperature : float = 1.0F + temperature: float = 1.0F
+ top_p : float = 0.95F + top_p: float = 0.95F
+ top_k : uint32_t = 64 + top_k: uint32_t = 64
+ n_ctx : uint32_t = 8192 + n_ctx: uint32_t = 8192
+ seed : int = -1 + seed: int = -1
} }
class GeneratorOptions { class GeneratorOptions {
+ model_path : std::filesystem::path + model_path: std::filesystem::path
+ use_mocked : bool = false + use_mocked: bool = false
+ sampling : SamplingOptions + sampling: std::optional<SamplingOptions>
} }
class PipelineOptions { class PipelineOptions {
+ output_path : std::filesystem::path + output_path: std::filesystem::path
+ log_path : std::filesystem::path + log_path: std::filesystem::path
} }
class ApplicationOptions { class ApplicationOptions {
+ generator : GeneratorOptions + generator: GeneratorOptions
+ pipeline : PipelineOptions + pipeline: PipelineOptions
} }
' --- Domain Model Relationships --- ' --- Domain Model Relationships ---
ApplicationOptions *-- GeneratorOptions ApplicationOptions *-- GeneratorOptions
ApplicationOptions *-- PipelineOptions ApplicationOptions *-- PipelineOptions
GeneratorOptions *-- SamplingOptions GeneratorOptions o-- SamplingOptions
} }
@enduml
package "Domain: Policy" { package "Domain: Policy" {

View File

@@ -85,14 +85,14 @@ endif()
FetchContent_Declare( FetchContent_Declare(
llama-cpp llama-cpp
GIT_REPOSITORY https://github.com/ggml-org/llama.cpp.git GIT_REPOSITORY https://github.com/ggml-org/llama.cpp.git
GIT_TAG b8742 GIT_TAG b8742
) )
FetchContent_MakeAvailable(llama-cpp) FetchContent_MakeAvailable(llama-cpp)
FetchContent_Declare( FetchContent_Declare(
boost-di boost-di
GIT_REPOSITORY https://github.com/boost-ext/di.git GIT_REPOSITORY https://github.com/boost-ext/di.git
GIT_TAG v1.3.0 GIT_TAG v1.3.0
) )
FetchContent_MakeAvailable(boost-di) FetchContent_MakeAvailable(boost-di)
if(TARGET Boost.DI AND NOT TARGET boost::di) if(TARGET Boost.DI AND NOT TARGET boost::di)
@@ -102,7 +102,7 @@ endif()
FetchContent_Declare( FetchContent_Declare(
spdlog spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v1.15.3 GIT_TAG v1.15.3
) )
FetchContent_MakeAvailable(spdlog) FetchContent_MakeAvailable(spdlog)
@@ -121,8 +121,8 @@ set(SOURCES
src/services/wikipedia/fetch_extract.cc src/services/wikipedia/fetch_extract.cc
src/services/sqlite/sqlite_export_service.cc src/services/sqlite/sqlite_export_service.cc
src/services/sqlite/build_database_path.cc src/services/sqlite/build_database_path.cc
src/services/sqlite/process_record.cc src/services/sqlite/process_record.cc
src/services/sqlite/initialize.cc src/services/sqlite/initialize.cc
src/services/sqlite/finalize.cc src/services/sqlite/finalize.cc
src/web_client/curl_global_state.cc src/web_client/curl_global_state.cc
src/web_client/curl_web_client_get.cc src/web_client/curl_web_client_get.cc
@@ -133,14 +133,14 @@ set(SOURCES
src/data_generation/llama/helpers.cc src/data_generation/llama/helpers.cc
src/data_generation/llama/infer.cc src/data_generation/llama/infer.cc
src/data_generation/llama/load.cc src/data_generation/llama/load.cc
src/data_generation/llama/load_brewery_prompt.cc src/services/prompt_directory.cc
src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc
src/data_generation/mock/deterministic_hash.cc src/data_generation/mock/deterministic_hash.cc
src/data_generation/mock/generate_brewery.cc src/data_generation/mock/generate_brewery.cc
src/data_generation/mock/generate_user.cc src/data_generation/mock/generate_user.cc
src/json_handling/json_loader.cc src/json_handling/json_loader.cc
src/services/sqlite/helpers/sqlite_connection_helpers.cpp src/services/sqlite/helpers/sqlite_connection_helpers.cpp
src/services/sqlite/helpers/sqlite_statement_helpers.cpp src/services/sqlite/helpers/sqlite_statement_helpers.cpp
) )
# ============================================================================= # =============================================================================
@@ -173,6 +173,6 @@ configure_file(
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_SOURCE_DIR}/prompts ${CMAKE_SOURCE_DIR}/prompts
${CMAKE_BINARY_DIR}/prompts ${CMAKE_BINARY_DIR}/prompts
) )

View File

@@ -17,6 +17,7 @@
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "data_generation/prompt_formatting/prompt_formatter.h" #include "data_generation/prompt_formatting/prompt_formatter.h"
#include "data_model/application_options.h" #include "data_model/application_options.h"
#include "services/prompt_directory.h"
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
@@ -33,10 +34,12 @@ class LlamaGenerator final : public DataGenerator {
* @param options Parsed application options. * @param options Parsed application options.
* @param model_path Filesystem path to GGUF model assets. * @param model_path Filesystem path to GGUF model assets.
* @param prompt_formatter Formatter that produces model-specific prompts. * @param prompt_formatter Formatter that produces model-specific prompts.
* @param prompt_directory Directory service for loading named prompt files.
*/ */
LlamaGenerator(const ApplicationOptions& options, LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path, const std::string& model_path,
std::unique_ptr<IPromptFormatter> prompt_formatter); std::unique_ptr<IPromptFormatter> prompt_formatter,
std::unique_ptr<IPromptDirectory> prompt_directory);
~LlamaGenerator() override; ~LlamaGenerator() override;
@@ -119,15 +122,6 @@ class LlamaGenerator final : public DataGenerator {
int max_tokens = kDefaultMaxTokens, int max_tokens = kDefaultMaxTokens,
std::string_view grammar = {}); std::string_view grammar = {});
/**
* @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text.
*/
std::string LoadBrewerySystemPrompt(
const std::filesystem::path& prompt_file_path);
ModelHandle model_; ModelHandle model_;
ContextHandle context_; ContextHandle context_;
float sampling_temperature_ = 1.0F; float sampling_temperature_ = 1.0F;
@@ -135,8 +129,8 @@ class LlamaGenerator final : public DataGenerator {
uint32_t sampling_top_k_ = kDefaultSamplingTopK; uint32_t sampling_top_k_ = kDefaultSamplingTopK;
std::mt19937 rng_; std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize; uint32_t n_ctx_ = kDefaultContextSize;
std::string brewery_system_prompt_;
std::unique_ptr<IPromptFormatter> prompt_formatter_; std::unique_ptr<IPromptFormatter> prompt_formatter_;
std::unique_ptr<IPromptDirectory> prompt_directory_;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -7,36 +7,66 @@
*/ */
#include <cstdint> #include <cstdint>
#include <filesystem>
#include <optional>
#include <string> #include <string>
/** /**
* @brief Program options for the Biergarten pipeline application. * @brief LLM sampling parameters.
*/ */
struct ApplicationOptions { struct SamplingOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F; float temperature = 1.0F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more /// @brief LLM nucleus sampling top-p parameter.
/// random).
float top_p = 0.95F; float top_p = 0.95F;
/// @brief LLM top-k sampling parameter. /// @brief LLM top-k sampling parameter.
uint32_t top_k = 64; uint32_t top_k = 64;
/// @brief Context window size (tokens) for LLM inference. Higher values /// @brief Context window size (tokens).
/// support longer prompts but use more memory.
uint32_t n_ctx = 8192; uint32_t n_ctx = 8192;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
}; };
/**
* @brief Configuration for the LLM generator component.
*
* Groups everything needed to choose and configure the data generator:
* either a real model loaded from @ref model_path or the mocked generator.
*/
struct GeneratorOptions {
/// @brief Path to the LLM model file (gguf format). Argument parsing rejects
/// a non-empty model path combined with use_mocked.
std::filesystem::path model_path;
/// @brief Use mocked generator instead of actual LLM inference.
bool use_mocked = false;
/// @brief Sampling parameters explicitly supplied for this generator.
/// When nullopt, consumers fall back to a default-constructed
/// SamplingOptions (via value_or(SamplingOptions{})).
std::optional<SamplingOptions> sampling;
};
/**
* @brief Configuration for the pipeline execution and output.
*
* All members are filesystem locations; none of them is validated by this
* struct itself.
*/
struct PipelineOptions {
/// @brief Directory for generated artifacts. The SQLite export service
/// places its database file directly inside this directory.
std::filesystem::path output_path;
/// @brief Directory that contains named prompt files (e.g.
/// BREWERY_GENERATION.md). A prompt key resolves to <prompt_dir>/<key>.md.
/// Argument parsing requires a non-empty value unless the mocked generator
/// is selected.
std::filesystem::path prompt_dir;
/// @brief Path for application logs.
std::filesystem::path log_path;
};
/**
* @brief Root configuration object for the Biergarten pipeline.
*
* Aggregates the per-concern option groups so that components can accept a
* single object and read only the section they care about.
*/
struct ApplicationOptions {
/// @brief Generator selection, model path and optional sampling overrides.
GeneratorOptions generator;
/// @brief Output, prompt and log locations for the pipeline run.
PipelineOptions pipeline;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_

View File

@@ -6,6 +6,8 @@
* @brief Abstraction for persisting generated brewery data. * @brief Abstraction for persisting generated brewery data.
*/ */
#include <cstdint>
#include "data_model/generated_brewery.h" #include "data_model/generated_brewery.h"
/** /**

View File

@@ -0,0 +1,76 @@
#ifndef BIERGARTEN_PIPELINE_INCLUDES_SERVICES_PROMPT_DIRECTORY_H_
#define BIERGARTEN_PIPELINE_INCLUDES_SERVICES_PROMPT_DIRECTORY_H_
/**
* @file services/prompt_directory.h
* @brief Interface and filesystem-backed implementation for named prompt
* loading.
*
* Prompt files are resolved by key: a key of "BREWERY_GENERATION" maps to the
* file <prompt_dir>/BREWERY_GENERATION.md. The interface is kept intentionally
* narrow so test doubles can be injected without touching the filesystem.
*/
#include <filesystem>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
/**
 * @brief Contract for resolving a logical prompt key to its prompt text.
 *
 * Implementations decide where the text comes from; callers only deal with
 * keys. Instances are neither copyable nor movable and are meant to be used
 * through a pointer or reference to this interface.
 */
class IPromptDirectory {
 public:
  virtual ~IPromptDirectory() = default;

  IPromptDirectory() = default;

  // Copying and moving are disabled for every implementation.
  IPromptDirectory(const IPromptDirectory&) = delete;
  IPromptDirectory(IPromptDirectory&&) = delete;
  IPromptDirectory& operator=(const IPromptDirectory&) = delete;
  IPromptDirectory& operator=(IPromptDirectory&&) = delete;

  /**
   * @brief Returns the prompt text registered under @p key.
   *
   * @param key Logical prompt key, e.g. "BREWERY_GENERATION".
   * @return Prompt text.
   * @throws std::runtime_error if the prompt file cannot be found or read.
   */
  [[nodiscard]] virtual std::string Load(std::string_view key) = 0;
};
/**
* @brief Filesystem-backed IPromptDirectory implementation.
*
* Each call to Load() checks an in-process cache first, then reads
* <prompt_dir>/<key>.md from disk. The directory must exist and be readable
* at construction time; individual file absence is reported lazily at Load().
*
* NOTE(review): no synchronization is visible around the cache, so concurrent
* Load() calls on one instance are presumably unsupported — confirm.
*/
class PromptDirectory final : public IPromptDirectory {
public:
/**
* @brief Constructs a PromptDirectory rooted at @p prompt_dir.
*
* @param prompt_dir Absolute or relative path to the prompt directory.
* @throws std::runtime_error if @p prompt_dir does not exist, is not a
* directory, or cannot be opened for iteration.
*/
explicit PromptDirectory(const std::filesystem::path& prompt_dir);
/**
* @brief Loads the prompt for @p key, caching the result.
*
* Maps @p key → <prompt_dir>/<key>.md. Later calls with the same key return
* the cached text without reading the file again.
*
* @param key Logical prompt key.
* @return Prompt text (returned by value).
* @throws std::runtime_error if the file cannot be opened or is empty.
*/
[[nodiscard]] std::string Load(std::string_view key) override;
private:
/// Root directory that prompt keys are resolved against.
std::filesystem::path prompt_dir_;
/// Prompt text already loaded during this run, indexed by prompt key.
std::unordered_map<std::string, std::string> cache_;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_PROMPT_DIRECTORY_H_

View File

@@ -11,6 +11,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "data_model/application_options.h"
#include "services/date_time_provider.h" #include "services/date_time_provider.h"
#include "services/export_service.h" #include "services/export_service.h"
#include "services/sqlite_export_service_helpers.h" #include "services/sqlite_export_service_helpers.h"
@@ -20,7 +21,7 @@
*/ */
class SqliteExportService final : public IExportService { class SqliteExportService final : public IExportService {
public: public:
SqliteExportService(); explicit SqliteExportService(const ApplicationOptions& options);
~SqliteExportService() override; ~SqliteExportService() override;
SqliteExportService(const SqliteExportService&) = delete; SqliteExportService(const SqliteExportService&) = delete;
@@ -47,6 +48,7 @@ class SqliteExportService final : public IExportService {
[[nodiscard]] static std::string BuildLocationKey(const Location& location); [[nodiscard]] static std::string BuildLocationKey(const Location& location);
std::unique_ptr<IDateTimeProvider> date_time_provider_; std::unique_ptr<IDateTimeProvider> date_time_provider_;
std::filesystem::path output_path_;
std::string run_timestamp_utc_; std::string run_timestamp_utc_;
std::filesystem::path database_path_; std::filesystem::path database_path_;
SqliteDatabaseHandle db_handle_; SqliteDatabaseHandle db_handle_;

View File

@@ -59,11 +59,12 @@ BreweryResult LlamaGenerator::GenerateBrewery(
location.country.empty() ? std::string{} location.country.empty() ? std::string{}
: std::format(", {}", location.country); : std::format(", {}", location.country);
/** /**
* Load brewery system prompt from file * Load brewery system prompt via the injected prompt directory.
* Falls back to minimal inline prompt if file not found * The key "BREWERY_GENERATION" resolves to BREWERY_GENERATION.md inside
* the configured --prompt-dir. Throws on missing or empty file.
*/ */
const std::string system_prompt = const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md"); prompt_directory_->Load("BREWERY_GENERATION");
std::string user_prompt = std::format( std::string user_prompt = std::format(
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## LOCAL LANGUAGE CODES:\n{}\n\n## " "## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## LOCAL LANGUAGE CODES:\n{}\n\n## "

View File

@@ -32,9 +32,11 @@ void LlamaGenerator::ContextDeleter::operator()(
LlamaGenerator::LlamaGenerator( LlamaGenerator::LlamaGenerator(
const ApplicationOptions& options, const std::string& model_path, const ApplicationOptions& options, const std::string& model_path,
std::unique_ptr<IPromptFormatter> prompt_formatter) std::unique_ptr<IPromptFormatter> prompt_formatter,
std::unique_ptr<IPromptDirectory> prompt_directory)
: rng_(std::random_device{}()), : rng_(std::random_device{}()),
prompt_formatter_(std::move(prompt_formatter)) { prompt_formatter_(std::move(prompt_formatter)),
prompt_directory_(std::move(prompt_directory)) {
if (model_path.empty()) { if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty"); throw std::runtime_error("LlamaGenerator: model path must not be empty");
} }
@@ -44,41 +46,49 @@ LlamaGenerator::LlamaGenerator(
"LlamaGenerator: prompt formatter dependency must not be null"); "LlamaGenerator: prompt formatter dependency must not be null");
} }
if (options.temperature < 0.0F) { if (!prompt_directory_) {
throw std::runtime_error(
"LlamaGenerator: prompt directory dependency must not be null");
}
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
if (sampling.temperature < 0.0F) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0"); "LlamaGenerator: sampling temperature must be >= 0");
} }
if (options.top_p <= 0.0F || options.top_p > 1.0F) { if (sampling.top_p <= 0.0F || sampling.top_p > 1.0F) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]"); "LlamaGenerator: sampling top-p must be in (0, 1]");
} }
if (options.top_k == 0U) { if (sampling.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0"); throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
} }
if (options.seed < -1) { if (sampling.seed < -1) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random"); "LlamaGenerator: seed must be >= 0, or -1 for random");
} }
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) { if (sampling.n_ctx == 0 || sampling.n_ctx > kMaxContextSize) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]"); "LlamaGenerator: context size must be in range [1, 32768]");
} }
sampling_temperature_ = options.temperature; sampling_temperature_ = sampling.temperature;
sampling_top_p_ = options.top_p; sampling_top_p_ = sampling.top_p;
sampling_top_k_ = options.top_k; sampling_top_k_ = sampling.top_k;
if (options.seed == -1) { if (sampling.seed == -1) {
std::random_device random_device; std::random_device random_device;
rng_.seed(random_device()); rng_.seed(random_device());
} else { } else {
rng_.seed(static_cast<uint32_t>(options.seed)); rng_.seed(static_cast<uint32_t>(sampling.seed));
} }
n_ctx_ = options.n_ctx;
n_ctx_ = sampling.n_ctx;
this->Load(model_path); this->Load(model_path);
} }

View File

@@ -1,55 +0,0 @@
/**
* @file data_generation/llama/load_brewery_prompt.cc
* @brief Resolves brewery system prompt content from cache or a configured
* filesystem path; throws when the file is missing or empty (no fallback).
*/
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h"
/**
* @brief Loads brewery system prompt from disk or cache.
*
* @param prompt_file_path Preferred prompt file location.
* @return Prompt text loaded from disk.
*/
std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::filesystem::path& prompt_file_path) {
// Return cached version if already loaded
if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_;
}
std::ifstream prompt_file(prompt_file_path);
if (!prompt_file.is_open()) {
spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_file_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_file_path.string());
}
const std::string prompt((std::istreambuf_iterator(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
if (prompt.empty()) {
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_file_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_file_path.string());
}
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
prompt_file_path.string(), prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}

View File

@@ -9,6 +9,7 @@
#include <boost/di.hpp> #include <boost/di.hpp>
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include <chrono> #include <chrono>
#include <cstdint>
#include <exception> #include <exception>
#include <memory> #include <memory>
#include <optional> #include <optional>
@@ -23,6 +24,7 @@
#include "llama_backend_state.h" #include "llama_backend_state.h"
#include "services/enrichment_service.h" #include "services/enrichment_service.h"
#include "services/export_service.h" #include "services/export_service.h"
#include "services/prompt_directory.h"
#include "services/sqlite_export_service.h" #include "services/sqlite_export_service.h"
#include "services/wikipedia_service.h" #include "services/wikipedia_service.h"
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
@@ -45,28 +47,34 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
opt("help,h", "Produce help message"); opt("help,h", "Produce help message");
// Generator Options
opt("mocked", prog_opts::bool_switch(), opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data"); "Use mocked generator for brewery/user data");
opt("model,m", prog_opts::value<std::string>()->default_value(""), opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)"); "Path to LLM model (gguf)");
// Sampling Options
opt("temperature", prog_opts::value<float>()->default_value(1.0F), opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)"); "Sampling temperature (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F), opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)"); "Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64), opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)"); "Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192), opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)"); "Context window size in tokens");
opt("seed", prog_opts::value<int>()->default_value(-1), opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer"); "Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case // Pipeline Options
opt("output,o", prog_opts::value<std::string>()->default_value("output"),
"Directory for generated artifacts");
opt("log-path",
prog_opts::value<std::string>()->default_value("pipeline.log"),
"Path for application logs");
opt("prompt-dir", prog_opts::value<std::string>()->default_value(""),
"Directory containing named prompt files (e.g. BREWERY_GENERATION.md)."
" Required when not using --mocked.");
if (argc == 1) { if (argc == 1) {
spdlog::info("Biergarten Pipeline"); spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream; std::stringstream usage_stream;
@@ -76,20 +84,25 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
} }
try { try {
prog_opts::variables_map variables_map; prog_opts::variables_map vm;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), vm);
variables_map); prog_opts::notify(vm);
prog_opts::notify(variables_map);
if (variables_map.contains("help")) { if (vm.contains("help")) {
std::stringstream help_stream; std::stringstream help_stream;
help_stream << "\n" << desc; help_stream << "\n" << desc;
spdlog::info(help_stream.str()); spdlog::info(help_stream.str());
return std::nullopt; return std::nullopt;
} }
const auto use_mocked = variables_map["mocked"].as<bool>(); ApplicationOptions options;
const auto model_path = variables_map["model"].as<std::string>();
options.pipeline.output_path = vm["output"].as<std::string>();
options.pipeline.log_path = vm["log-path"].as<std::string>();
options.pipeline.prompt_dir = vm["prompt-dir"].as<std::string>();
const bool use_mocked = vm["mocked"].as<bool>();
const std::string model_path = vm["model"].as<std::string>();
if (use_mocked && !model_path.empty()) { if (use_mocked && !model_path.empty()) {
spdlog::error( spdlog::error(
@@ -103,25 +116,35 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
return std::nullopt; return std::nullopt;
} }
const bool has_llm_params = !variables_map["temperature"].defaulted() || if (!use_mocked && options.pipeline.prompt_dir.empty()) {
!variables_map["top-p"].defaulted() || spdlog::error(
!variables_map["top-k"].defaulted() || "Invalid arguments: --prompt-dir is required when not using "
!variables_map["seed"].defaulted(); "--mocked");
return std::nullopt;
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
} }
ApplicationOptions options; options.generator.use_mocked = use_mocked;
options.use_mocked = use_mocked; options.generator.model_path = model_path;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>(); const bool user_provided_sampling =
options.top_p = variables_map["top-p"].as<float>(); !vm["temperature"].defaulted() || !vm["top-p"].defaulted() ||
options.top_k = variables_map["top-k"].as<uint32_t>(); !vm["top-k"].defaulted() || !vm["n-ctx"].defaulted() ||
options.n_ctx = variables_map["n-ctx"].as<uint32_t>(); !vm["seed"].defaulted();
options.seed = variables_map["seed"].as<int>();
if (use_mocked) {
if (user_provided_sampling) {
spdlog::warn("Sampling parameters are ignored when using --mocked");
}
} else if (user_provided_sampling) {
SamplingOptions sampling;
sampling.temperature = vm["temperature"].as<float>();
sampling.top_p = vm["top-p"].as<float>();
sampling.top_k = vm["top-k"].as<uint32_t>();
sampling.n_ctx = vm["n-ctx"].as<uint32_t>();
sampling.seed = vm["seed"].as<int>();
options.generator.sampling = sampling;
}
return options; return options;
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
@@ -157,6 +180,22 @@ int main(const int argc, char** argv) {
} }
const auto options = *parsed_options; const auto options = *parsed_options;
const std::string model_path = options.generator.model_path.string();
const auto sampling =
options.generator.sampling.value_or(SamplingOptions{});
// Scenario 4: Validate the prompt directory up-front, before any DI
// wiring, so the error surfaces immediately with a clear message.
std::unique_ptr<IPromptDirectory> prompt_directory;
if (!options.generator.use_mocked) {
try {
prompt_directory =
std::make_unique<PromptDirectory>(options.pipeline.prompt_dir);
} catch (const std::exception& dir_error) {
spdlog::error("[Startup] Invalid --prompt-dir: {}", dir_error.what());
return 1;
}
}
const auto injector = di::make_injector( const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(), di::bind<WebClient>().to<CURLWebClient>(),
@@ -164,10 +203,11 @@ int main(const int argc, char** argv) {
di::bind<IEnrichmentService>().to<WikipediaService>(), di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<IExportService>().to<SqliteExportService>(), di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(), di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<std::string>().to(options.model_path), di::bind<std::string>().to(model_path),
di::bind<DataGenerator>().to( di::bind<DataGenerator>().to(
[options](const auto& inj) -> std::unique_ptr<DataGenerator> { [options, model_path, sampling, &prompt_directory](
if (options.use_mocked) { const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) {
spdlog::info( spdlog::info(
"[Generator] Using MockGenerator (no model path provided)"); "[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>(); return std::make_unique<MockGenerator>();
@@ -176,9 +216,15 @@ int main(const int argc, char** argv) {
spdlog::info( spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, " "[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})", "top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p, model_path, sampling.temperature, sampling.top_p,
options.top_k, options.n_ctx, options.seed); sampling.top_k, sampling.n_ctx, sampling.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>(); // Transfer ownership of the pre-validated PromptDirectory into
// the LlamaGenerator. The lambda captures by reference so the
// unique_ptr is moved exactly once.
return std::make_unique<LlamaGenerator>(
options, model_path,
inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory));
})); }));
auto generator = auto generator =

View File

@@ -0,0 +1,85 @@
/**
* @file services/prompt_directory.cc
* @brief PromptDirectory implementation: validates the directory at
* construction and loads named prompt files on demand with in-process caching.
*/
#include "services/prompt_directory.h"
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>
#include <string_view>
// ---------------------------------------------------------------------------
// PromptDirectory
// ---------------------------------------------------------------------------
/**
 * @brief Validates @p prompt_dir and remembers it as the prompt root.
 *
 * Three checks run in order, each using the non-throwing filesystem
 * overloads: the path exists, it is a directory, and a directory iterator can
 * be opened on it. Any failure is reported as std::runtime_error.
 *
 * @param prompt_dir Directory that holds the <key>.md prompt files.
 * @throws std::runtime_error if any of the three checks fails.
 */
PromptDirectory::PromptDirectory(const std::filesystem::path& prompt_dir)
    : prompt_dir_(prompt_dir) {
  namespace fs = std::filesystem;

  // Builds "<reason><directory>" for the checks that carry no extra detail.
  const auto reject = [this](const char* reason) {
    return std::runtime_error(std::string(reason) + prompt_dir_.string());
  };

  std::error_code status;

  // Scenario 4: directory must exist.
  const bool present = fs::exists(prompt_dir_, status);
  if (status || !present) {
    throw reject("PromptDirectory: prompt directory does not exist: ");
  }

  // Scenario 4: path must be a directory, not a file.
  const bool is_dir = fs::is_directory(prompt_dir_, status);
  if (status || !is_dir) {
    throw reject(
        "PromptDirectory: prompt directory path is not a directory: ");
  }

  // Scenario 4: directory must be readable; opening an iterator is the probe.
  const fs::directory_iterator listing(prompt_dir_, status);
  if (status) {
    throw std::runtime_error(
        "PromptDirectory: prompt directory is not readable: " +
        prompt_dir_.string() + " (" + status.message() + ")");
  }

  spdlog::info("[PromptDirectory] Resolved prompt directory: {}",
               prompt_dir_.string());
}
std::string PromptDirectory::Load(std::string_view key) {
const std::string key_str(key);
// Return cached content if already loaded during this run.
const auto cache_it = cache_.find(key_str);
if (cache_it != cache_.end()) {
return cache_it->second;
}
// Scenario 3: resolve <prompt_dir>/<key>.md and require it to exist.
const std::filesystem::path file_path =
prompt_dir_ / std::filesystem::path(key_str + ".md");
std::ifstream file(file_path);
if (!file.is_open()) {
throw std::runtime_error(
"PromptDirectory: prompt file not found for key '" + key_str +
"': " + file_path.string());
}
std::string content((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
file.close();
if (content.empty()) {
throw std::runtime_error("PromptDirectory: prompt file for key '" +
key_str + "' is empty: " + file_path.string());
}
spdlog::info("[PromptDirectory] Loaded prompt '{}' from '{}' ({} chars)",
key_str, file_path.string(), content.size());
cache_.emplace(key_str, content);
return content;
}

View File

@@ -11,11 +11,10 @@
std::filesystem::path SqliteExportService::BuildDatabasePath() const { std::filesystem::path SqliteExportService::BuildDatabasePath() const {
std::filesystem::path base_filename("biergarten_seed_" + run_timestamp_utc_ + std::filesystem::path base_filename("biergarten_seed_" + run_timestamp_utc_ +
".sqlite"); ".sqlite");
std::filesystem::path candidate = std::filesystem::path candidate = output_path_ / base_filename;
std::filesystem::current_path() / base_filename;
for (int suffix = 1; std::filesystem::exists(candidate); ++suffix) { for (int suffix = 1; std::filesystem::exists(candidate); ++suffix) {
candidate = std::filesystem::current_path() / candidate = output_path_ /
std::filesystem::path("biergarten_seed_" + run_timestamp_utc_ + std::filesystem::path("biergarten_seed_" + run_timestamp_utc_ +
"-" + std::to_string(suffix) + ".sqlite"); "-" + std::to_string(suffix) + ".sqlite");
} }

View File

@@ -7,11 +7,12 @@
#include <memory> #include <memory>
SqliteExportService::SqliteExportService() SqliteExportService::SqliteExportService(const ApplicationOptions& options)
: date_time_provider_(std::make_unique<SystemDateTimeProvider>()) {} : date_time_provider_(std::make_unique<SystemDateTimeProvider>()),
output_path_(options.pipeline.output_path) {}
SqliteExportService::~SqliteExportService() { SqliteExportService::~SqliteExportService() {
if (db_handle_ != nullptr) { if (db_handle_ != nullptr) {
RollbackAndCloseNoThrow(); RollbackAndCloseNoThrow();
} }
} }