4 Commits

Author SHA1 Message Date
Aaron Po
271c6fa99f update docs 2026-05-01 18:30:38 -04:00
Aaron Po
316fda1775 codebase formatting 2026-05-01 17:40:37 -04:00
Aaron Po
91e18888fe readability updates: remove magic numbers, update comments 2026-05-01 17:38:16 -04:00
Aaron Po
9051f55114 add prompt dir app option 2026-05-01 12:25:05 -04:00
29 changed files with 299 additions and 137 deletions

View File

@@ -6,3 +6,4 @@ data
models models
*.gguf *.gguf
BiergartenPipeline.png BiergartenPipeline.png
output

View File

@@ -133,7 +133,7 @@ set(SOURCES
src/data_generation/llama/helpers.cc src/data_generation/llama/helpers.cc
src/data_generation/llama/infer.cc src/data_generation/llama/infer.cc
src/data_generation/llama/load.cc src/data_generation/llama/load.cc
src/data_generation/llama/load_brewery_prompt.cc src/services/prompt_directory.cc
src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc
src/data_generation/mock/deterministic_hash.cc src/data_generation/mock/deterministic_hash.cc
src/data_generation/mock/generate_brewery.cc src/data_generation/mock/generate_brewery.cc

View File

@@ -17,6 +17,7 @@
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "data_generation/prompt_formatting/prompt_formatter.h" #include "data_generation/prompt_formatting/prompt_formatter.h"
#include "data_model/application_options.h" #include "data_model/application_options.h"
#include "services/prompt_directory.h"
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
@@ -33,10 +34,12 @@ class LlamaGenerator final : public DataGenerator {
* @param options Parsed application options. * @param options Parsed application options.
* @param model_path Filesystem path to GGUF model assets. * @param model_path Filesystem path to GGUF model assets.
* @param prompt_formatter Formatter that produces model-specific prompts. * @param prompt_formatter Formatter that produces model-specific prompts.
* @param prompt_directory Directory service for loading named prompt files.
*/ */
LlamaGenerator(const ApplicationOptions& options, LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path, const std::string& model_path,
std::unique_ptr<IPromptFormatter> prompt_formatter); std::unique_ptr<IPromptFormatter> prompt_formatter,
std::unique_ptr<IPromptDirectory> prompt_directory);
~LlamaGenerator() override; ~LlamaGenerator() override;
@@ -119,15 +122,6 @@ class LlamaGenerator final : public DataGenerator {
int max_tokens = kDefaultMaxTokens, int max_tokens = kDefaultMaxTokens,
std::string_view grammar = {}); std::string_view grammar = {});
/**
* @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text.
*/
std::string LoadBrewerySystemPrompt(
const std::filesystem::path& prompt_file_path);
ModelHandle model_; ModelHandle model_;
ContextHandle context_; ContextHandle context_;
float sampling_temperature_ = 1.0F; float sampling_temperature_ = 1.0F;
@@ -135,8 +129,8 @@ class LlamaGenerator final : public DataGenerator {
uint32_t sampling_top_k_ = kDefaultSamplingTopK; uint32_t sampling_top_k_ = kDefaultSamplingTopK;
std::mt19937 rng_; std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize; uint32_t n_ctx_ = kDefaultContextSize;
std::string brewery_system_prompt_;
std::unique_ptr<IPromptFormatter> prompt_formatter_; std::unique_ptr<IPromptFormatter> prompt_formatter_;
std::unique_ptr<IPromptDirectory> prompt_directory_;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -44,6 +44,13 @@ class MockGenerator final : public DataGenerator {
*/ */
static size_t DeterministicHash(const Location& location); static size_t DeterministicHash(const Location& location);
// Divisors applied to the deterministic hash before the modulo when selecting
// a noun, description, or bio. Each field divides by a different constant, so
// its index is taken from a differently scaled value than plain `hash % size`.
static constexpr size_t kNounHashStride = 7;
static constexpr size_t kDescriptionHashStride = 13;
static constexpr size_t kBioHashStride = 11;
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = { static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel", "Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",

View File

@@ -53,6 +53,10 @@ struct PipelineOptions {
/// @brief Directory for generated artifacts. /// @brief Directory for generated artifacts.
std::filesystem::path output_path; std::filesystem::path output_path;
/// @brief Directory that contains named prompt files (e.g.
/// BREWERY_GENERATION.md).
std::filesystem::path prompt_dir;
/// @brief Path for application logs. /// @brief Path for application logs.
std::filesystem::path log_path; std::filesystem::path log_path;
}; };

View File

@@ -0,0 +1,76 @@
#ifndef BIERGARTEN_PIPELINE_INCLUDES_SERVICES_PROMPT_DIRECTORY_H_
#define BIERGARTEN_PIPELINE_INCLUDES_SERVICES_PROMPT_DIRECTORY_H_
/**
* @file services/prompt_directory.h
* @brief Interface and filesystem-backed implementation for named prompt
* loading.
*
* Prompt files are resolved by key: a key of "BREWERY_GENERATION" maps to the
* file <prompt_dir>/BREWERY_GENERATION.md. The interface is kept intentionally
* narrow so test doubles can be injected without touching the filesystem.
*/
#include <filesystem>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
/**
 * @brief Abstract source of prompt text addressed by a logical key.
 *
 * Copying and moving are disabled; implementations only need to provide
 * Load().
 */
class IPromptDirectory {
 public:
  IPromptDirectory() = default;
  virtual ~IPromptDirectory() = default;

  // Copy operations are deleted.
  IPromptDirectory(const IPromptDirectory&) = delete;
  IPromptDirectory& operator=(const IPromptDirectory&) = delete;

  // Move operations are deleted.
  IPromptDirectory(IPromptDirectory&&) = delete;
  IPromptDirectory& operator=(IPromptDirectory&&) = delete;

  /**
   * @brief Returns the prompt text registered under @p key.
   *
   * @param key Logical prompt key, e.g. "BREWERY_GENERATION".
   * @return The prompt text.
   * @throws std::runtime_error if the prompt file cannot be found or read.
   */
  [[nodiscard]] virtual std::string Load(std::string_view key) = 0;
};
/**
 * @brief IPromptDirectory implementation that reads prompts from a directory
 * on disk.
 *
 * Load() resolves a key to <prompt_dir>/<key>.md and keeps the loaded text in
 * an in-memory map, so a repeated request for the same key is answered from
 * that map. The directory itself is checked by the constructor; a missing
 * prompt file is only reported when Load() is called for its key.
 */
class PromptDirectory final : public IPromptDirectory {
 public:
  /**
   * @brief Creates a prompt directory rooted at @p prompt_dir.
   *
   * @param prompt_dir Absolute or relative path of the directory that holds
   * the prompt files.
   * @throws std::runtime_error if @p prompt_dir does not exist or is not a
   * directory.
   */
  explicit PromptDirectory(const std::filesystem::path& prompt_dir);

  /**
   * @brief Returns the prompt stored in <prompt_dir>/<key>.md.
   *
   * The text is cached after the first successful load.
   *
   * @param key Logical prompt key.
   * @return The prompt text.
   * @throws std::runtime_error if the file does not exist or is empty.
   */
  [[nodiscard]] std::string Load(std::string_view key) override;

 private:
  // Root directory that keys are resolved against.
  std::filesystem::path prompt_dir_;
  // Prompt text already loaded, indexed by key.
  std::unordered_map<std::string, std::string> cache_;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_PROMPT_DIRECTORY_H_

View File

@@ -7,6 +7,7 @@
*/ */
#include <sqlite3.h> #include <sqlite3.h>
#include <filesystem> #include <filesystem>
#include <string> #include <string>
#include <string_view> #include <string_view>
@@ -20,12 +21,10 @@ void ThrowSqliteError(sqlite3* db_handle, std::string_view action);
SqliteDatabaseHandle OpenDatabase(const std::filesystem::path& path); SqliteDatabaseHandle OpenDatabase(const std::filesystem::path& path);
void ExecSql(const SqliteDatabaseHandle& db_handle, std::string_view sql, void ExecSql(const SqliteDatabaseHandle& db_handle, std::string_view sql,
const char* action); const char* action);
void RollbackTransactionNoThrow(const SqliteDatabaseHandle& db_handle) noexcept; void RollbackTransactionNoThrow(const SqliteDatabaseHandle& db_handle) noexcept;
} // namespace sqlite_export_service_internal } // namespace sqlite_export_service_internal
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_CONNECTION_HELPERS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_CONNECTION_HELPERS_H_

View File

@@ -42,7 +42,6 @@ class SqliteExportService final : public IExportService {
void InitializeSchema() const; void InitializeSchema() const;
void PrepareStatements(); void PrepareStatements();
void RollbackAndCloseNoThrow() noexcept; void RollbackAndCloseNoThrow() noexcept;
void FinalizeStatements() noexcept;
[[nodiscard]] std::filesystem::path BuildDatabasePath() const; [[nodiscard]] std::filesystem::path BuildDatabasePath() const;
[[nodiscard]] static std::string BuildLocationKey(const Location& location); [[nodiscard]] static std::string BuildLocationKey(const Location& location);

View File

@@ -3,8 +3,8 @@
/* Umbrella header for backward compatibility. */ /* Umbrella header for backward compatibility. */
#include "services/sqlite_handle_types.h"
#include "services/sqlite_connection_helpers.h" #include "services/sqlite_connection_helpers.h"
#include "services/sqlite_handle_types.h"
#include "services/sqlite_statement_helpers.h" #include "services/sqlite_statement_helpers.h"
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_EXPORT_SERVICE_HELPERS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_EXPORT_SERVICE_HELPERS_H_

View File

@@ -6,6 +6,7 @@
*/ */
#include <sqlite3.h> #include <sqlite3.h>
#include <memory> #include <memory>
#include <string_view> #include <string_view>
@@ -33,4 +34,3 @@ struct BindParam {
} // namespace sqlite_export_service_internal } // namespace sqlite_export_service_internal
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_HANDLE_TYPES_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_HANDLE_TYPES_H_

View File

@@ -3,10 +3,12 @@
/** /**
* @file services/sqlite_statement_helpers.h * @file services/sqlite_statement_helpers.h
* @brief Declarations for statement-level SQLite helper functions and constants. * @brief Declarations for statement-level SQLite helper functions and
* constants.
*/ */
#include <sqlite3.h> #include <sqlite3.h>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector> #include <vector>
@@ -107,10 +109,8 @@ void StepStatement(const SqliteDatabaseHandle& db_handle,
sqlite3_int64 LastInsertRowId(const SqliteDatabaseHandle& db_handle); sqlite3_int64 LastInsertRowId(const SqliteDatabaseHandle& db_handle);
std::string SerializeLocalLanguages(const std::vector<std::string>& local_languages);
std::string SerializeVector(const std::vector<std::string>& str_vec); std::string SerializeVector(const std::vector<std::string>& str_vec);
} // namespace sqlite_export_service_internal } // namespace sqlite_export_service_internal
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_STATEMENT_HELPERS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_SQLITE_STATEMENT_HELPERS_H_

View File

@@ -33,6 +33,9 @@ static std::string FormatLocalLanguageCodes(
return formatted; return formatted;
} }
// GBNF grammar for structured brewery JSON output.
// @TODO move to a separate gbnf file if it grows in complexity or is shared
// across modules.
static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery( static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery(
root ::= thought-block "{" ws "\"name_en\"" ws ":" ws string ws "," ws "\"description_en\"" ws ":" ws string ws "," ws "\"name_local\"" ws ":" ws string ws "," ws "\"description_local\"" ws ":" ws string ws "}" ws root ::= thought-block "{" ws "\"name_en\"" ws ":" ws string ws "," ws "\"description_en\"" ws ":" ws string ws "," ws "\"name_local\"" ws ":" ws string ws "," ws "\"description_local\"" ws ":" ws string ws "}" ws
thought-block ::= [^{]* thought-block ::= [^{]*
@@ -59,11 +62,12 @@ BreweryResult LlamaGenerator::GenerateBrewery(
location.country.empty() ? std::string{} location.country.empty() ? std::string{}
: std::format(", {}", location.country); : std::format(", {}", location.country);
/** /**
* Load brewery system prompt from file * Load brewery system prompt via the injected prompt directory.
* Falls back to minimal inline prompt if file not found * The key "BREWERY_GENERATION" resolves to BREWERY_GENERATION.md inside
* the configured --prompt-dir. Throws on missing or empty file.
*/ */
const std::string system_prompt = const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md"); prompt_directory_->Load("BREWERY_GENERATION");
std::string user_prompt = std::format( std::string user_prompt = std::format(
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## LOCAL LANGUAGE CODES:\n{}\n\n## " "## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## LOCAL LANGUAGE CODES:\n{}\n\n## "

View File

@@ -12,6 +12,13 @@
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
// TODO: Implement locale-aware user profile generation.
// Current implementation returns a hardcoded test value and ignores the
// locale parameter. Future implementation should:
// 1. Load a USER_GENERATION.md prompt template with locale context
// 2. Perform LLM inference with locale-specific username/bio generation
// 3. Parse and validate JSON output with retry handling (similar to brewery)
// 4. Return locale-aware username and biography
UserResult LlamaGenerator::GenerateUser(const std::string& locale) { UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user", return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."}; .bio = "This is a test user profile from " + locale + "."};

View File

@@ -58,6 +58,11 @@ static std::string CondenseWhitespace(std::string_view text) {
return out; return out;
} }
// Guard against truncating in the first half of the string.
// This preserves the critical opening content and avoids cutting critical
// context words early in the region description.
static constexpr size_t kTruncationGuardDivisor = 2;
/** /**
* Truncate region context to fit within max length while preserving word * Truncate region context to fit within max length while preserving word
* boundaries * boundaries
@@ -71,7 +76,8 @@ std::string PrepareRegionContext(std::string_view region_context,
normalized.resize(max_chars); normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' '); const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) { if (last_space != std::string::npos &&
last_space > max_chars / kTruncationGuardDivisor) {
normalized.resize(last_space); normalized.resize(last_space);
} }

View File

@@ -19,6 +19,9 @@
#include "llama.h" #include "llama.h"
static constexpr size_t kPromptTokenSlack = 8; static constexpr size_t kPromptTokenSlack = 8;
// Minimum tokens to keep when using top-p sampling. Ensures at least one
// candidate token remains available even with very restrictive top-p values.
static constexpr size_t kTopPMinKeep = 1;
namespace { namespace {
@@ -62,7 +65,7 @@ SamplerHandle MakeSamplerChain(const llama_vocab* vocab,
"LlamaGenerator: failed to initialize temperature sampler"); "LlamaGenerator: failed to initialize temperature sampler");
add_sampler(llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)), add_sampler(llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)),
"LlamaGenerator: failed to initialize top-k sampler"); "LlamaGenerator: failed to initialize top-k sampler");
add_sampler(llama_sampler_init_top_p(config.top_p, 1), add_sampler(llama_sampler_init_top_p(config.top_p, kTopPMinKeep),
"LlamaGenerator: failed to initialize top-p sampler"); "LlamaGenerator: failed to initialize top-p sampler");
add_sampler(llama_sampler_init_dist(config.seed), add_sampler(llama_sampler_init_dist(config.seed),
"LlamaGenerator: failed to initialize distribution sampler"); "LlamaGenerator: failed to initialize distribution sampler");

View File

@@ -32,9 +32,11 @@ void LlamaGenerator::ContextDeleter::operator()(
LlamaGenerator::LlamaGenerator( LlamaGenerator::LlamaGenerator(
const ApplicationOptions& options, const std::string& model_path, const ApplicationOptions& options, const std::string& model_path,
std::unique_ptr<IPromptFormatter> prompt_formatter) std::unique_ptr<IPromptFormatter> prompt_formatter,
std::unique_ptr<IPromptDirectory> prompt_directory)
: rng_(std::random_device{}()), : rng_(std::random_device{}()),
prompt_formatter_(std::move(prompt_formatter)) { prompt_formatter_(std::move(prompt_formatter)),
prompt_directory_(std::move(prompt_directory)) {
if (model_path.empty()) { if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty"); throw std::runtime_error("LlamaGenerator: model path must not be empty");
} }
@@ -44,6 +46,11 @@ LlamaGenerator::LlamaGenerator(
"LlamaGenerator: prompt formatter dependency must not be null"); "LlamaGenerator: prompt formatter dependency must not be null");
} }
if (!prompt_directory_) {
throw std::runtime_error(
"LlamaGenerator: prompt directory dependency must not be null");
}
const auto sampling = options.generator.sampling.value_or(SamplingOptions{}); const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
if (sampling.temperature < 0.0F) { if (sampling.temperature < 0.0F) {

View File

@@ -14,6 +14,10 @@
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h" #include "llama.h"
// Maximum batch size for decode operations. Capping the batch prevents
// excessive memory allocation while maintaining inference performance.
static constexpr uint32_t kMaxBatchSize = 5000U;
void LlamaGenerator::Load(const std::string& model_path) { void LlamaGenerator::Load(const std::string& model_path) {
context_.reset(); context_.reset();
model_.reset(); model_.reset();
@@ -28,7 +32,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
llama_context_params context_params = llama_context_default_params(); llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_; context_params.n_ctx = n_ctx_;
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000)); context_params.n_batch = std::min(n_ctx_, kMaxBatchSize);
LlamaGenerator::ContextHandle loaded_context( LlamaGenerator::ContextHandle loaded_context(
llama_init_from_model(loaded_model.get(), context_params)); llama_init_from_model(loaded_model.get(), context_params));

View File

@@ -1,55 +0,0 @@
/**
 * @file data_generation/llama/load_brewery_prompt.cc
 * @brief Resolves brewery system prompt content from the in-memory cache or a
 * configured filesystem path; throws if the file is missing or empty.
 */
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h"
/**
 * @brief Returns the brewery system prompt, reading it from disk on first use.
 *
 * Once a non-empty prompt has been stored in brewery_system_prompt_, later
 * calls return that stored text and do not look at @p prompt_file_path.
 *
 * @param prompt_file_path File to read when no prompt has been stored yet.
 * @return The prompt text.
 * @throws std::runtime_error if the file cannot be opened or is empty.
 */
std::string LlamaGenerator::LoadBrewerySystemPrompt(
    const std::filesystem::path& prompt_file_path) {
  if (brewery_system_prompt_.empty()) {
    std::ifstream input(prompt_file_path);
    if (!input.is_open()) {
      spdlog::error(
          "LlamaGenerator: Failed to open brewery system prompt file '{}'",
          prompt_file_path.string());
      throw std::runtime_error(
          "LlamaGenerator: missing brewery system prompt file: " +
          prompt_file_path.string());
    }

    // Read the whole file in one pass.
    const std::string loaded(std::istreambuf_iterator<char>{input},
                             std::istreambuf_iterator<char>{});
    input.close();

    if (loaded.empty()) {
      spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
                    prompt_file_path.string());
      throw std::runtime_error(
          "LlamaGenerator: empty brewery system prompt file: " +
          prompt_file_path.string());
    }

    spdlog::info(
        "LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
        prompt_file_path.string(), loaded.length());

    // Remember the text so subsequent calls skip the file read.
    brewery_system_prompt_ = loaded;
  }

  return brewery_system_prompt_;
}

View File

@@ -17,9 +17,9 @@ BreweryResult MockGenerator::GenerateBrewery(
const std::string_view adjective = const std::string_view adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size()); kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun = const std::string_view noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size()); kBreweryNouns.at(hash / kNounHashStride % kBreweryNouns.size());
const std::string_view base_description = const std::string_view base_description = kBreweryDescriptions.at(
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size()); (hash / kDescriptionHashStride) % kBreweryDescriptions.size());
const std::string name = const std::string name =
std::format("{} {} {}", location.city, adjective, noun); std::format("{} {} {}", location.city, adjective, noun);

View File

@@ -15,7 +15,7 @@ UserResult MockGenerator::GenerateUser(const std::string& locale) {
UserResult result; UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()]; const std::string_view username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()]; const std::string_view bio = kBios[hash / kBioHashStride % kBios.size()];
result.username = username; result.username = username;
result.bio = bio; result.bio = bio;
return result; return result;

View File

@@ -24,6 +24,7 @@
#include "llama_backend_state.h" #include "llama_backend_state.h"
#include "services/enrichment_service.h" #include "services/enrichment_service.h"
#include "services/export_service.h" #include "services/export_service.h"
#include "services/prompt_directory.h"
#include "services/sqlite_export_service.h" #include "services/sqlite_export_service.h"
#include "services/wikipedia_service.h" #include "services/wikipedia_service.h"
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
@@ -52,16 +53,21 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
opt("model,m", prog_opts::value<std::string>()->default_value(""), opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)"); "Path to LLM model (gguf)");
// Sampling Options // Sampling Options - defaults driven from SamplingOptions struct
opt("temperature", prog_opts::value<float>()->default_value(1.0F), const SamplingOptions kSamplingDefaults{};
opt("temperature",
prog_opts::value<float>()->default_value(kSamplingDefaults.temperature),
"Sampling temperature (higher = more random)"); "Sampling temperature (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F), opt("top-p",
prog_opts::value<float>()->default_value(kSamplingDefaults.top_p),
"Nucleus sampling top-p in (0,1] (higher = more random)"); "Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64), opt("top-k",
prog_opts::value<uint32_t>()->default_value(kSamplingDefaults.top_k),
"Top-k sampling parameter (higher = more candidate tokens)"); "Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192), opt("n-ctx",
prog_opts::value<uint32_t>()->default_value(kSamplingDefaults.n_ctx),
"Context window size in tokens"); "Context window size in tokens");
opt("seed", prog_opts::value<int>()->default_value(-1), opt("seed", prog_opts::value<int>()->default_value(kSamplingDefaults.seed),
"Sampler seed: -1 for random, otherwise non-negative integer"); "Sampler seed: -1 for random, otherwise non-negative integer");
// Pipeline Options // Pipeline Options
@@ -70,6 +76,9 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
opt("log-path", opt("log-path",
prog_opts::value<std::string>()->default_value("pipeline.log"), prog_opts::value<std::string>()->default_value("pipeline.log"),
"Path for application logs"); "Path for application logs");
opt("prompt-dir", prog_opts::value<std::string>()->default_value(""),
"Directory containing named prompt files (e.g. BREWERY_GENERATION.md)."
" Required when not using --mocked.");
if (argc == 1) { if (argc == 1) {
spdlog::info("Biergarten Pipeline"); spdlog::info("Biergarten Pipeline");
@@ -80,11 +89,11 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
} }
try { try {
prog_opts::variables_map vm; prog_opts::variables_map var_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), vm); prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), var_map);
prog_opts::notify(vm); prog_opts::notify(var_map);
if (vm.contains("help")) { if (var_map.contains("help")) {
std::stringstream help_stream; std::stringstream help_stream;
help_stream << "\n" << desc; help_stream << "\n" << desc;
spdlog::info(help_stream.str()); spdlog::info(help_stream.str());
@@ -93,11 +102,12 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
ApplicationOptions options; ApplicationOptions options;
options.pipeline.output_path = vm["output"].as<std::string>(); options.pipeline.output_path = var_map["output"].as<std::string>();
options.pipeline.log_path = vm["log-path"].as<std::string>(); options.pipeline.log_path = var_map["log-path"].as<std::string>();
options.pipeline.prompt_dir = var_map["prompt-dir"].as<std::string>();
const bool use_mocked = vm["mocked"].as<bool>(); const bool use_mocked = var_map["mocked"].as<bool>();
const std::string model_path = vm["model"].as<std::string>(); const std::string model_path = var_map["model"].as<std::string>();
if (use_mocked && !model_path.empty()) { if (use_mocked && !model_path.empty()) {
spdlog::error( spdlog::error(
@@ -111,13 +121,20 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
return std::nullopt; return std::nullopt;
} }
if (!use_mocked && options.pipeline.prompt_dir.empty()) {
spdlog::error(
"Invalid arguments: --prompt-dir is required when not using "
"--mocked");
return std::nullopt;
}
options.generator.use_mocked = use_mocked; options.generator.use_mocked = use_mocked;
options.generator.model_path = model_path; options.generator.model_path = model_path;
const bool user_provided_sampling = const bool user_provided_sampling =
!vm["temperature"].defaulted() || !vm["top-p"].defaulted() || !var_map["temperature"].defaulted() || !var_map["top-p"].defaulted() ||
!vm["top-k"].defaulted() || !vm["n-ctx"].defaulted() || !var_map["top-k"].defaulted() || !var_map["n-ctx"].defaulted() ||
!vm["seed"].defaulted(); !var_map["seed"].defaulted();
if (use_mocked) { if (use_mocked) {
if (user_provided_sampling) { if (user_provided_sampling) {
@@ -125,11 +142,11 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
} }
} else if (user_provided_sampling) { } else if (user_provided_sampling) {
SamplingOptions sampling; SamplingOptions sampling;
sampling.temperature = vm["temperature"].as<float>(); sampling.temperature = var_map["temperature"].as<float>();
sampling.top_p = vm["top-p"].as<float>(); sampling.top_p = var_map["top-p"].as<float>();
sampling.top_k = vm["top-k"].as<uint32_t>(); sampling.top_k = var_map["top-k"].as<uint32_t>();
sampling.n_ctx = vm["n-ctx"].as<uint32_t>(); sampling.n_ctx = var_map["n-ctx"].as<uint32_t>();
sampling.seed = vm["seed"].as<int>(); sampling.seed = var_map["seed"].as<int>();
options.generator.sampling = sampling; options.generator.sampling = sampling;
} }
@@ -172,6 +189,17 @@ int main(const int argc, char** argv) {
const auto sampling = const auto sampling =
options.generator.sampling.value_or(SamplingOptions{}); options.generator.sampling.value_or(SamplingOptions{});
std::unique_ptr<IPromptDirectory> prompt_directory;
if (!options.generator.use_mocked) {
try {
prompt_directory =
std::make_unique<PromptDirectory>(options.pipeline.prompt_dir);
} catch (const std::exception& dir_error) {
spdlog::error("[Startup] Invalid --prompt-dir: {}", dir_error.what());
return 1;
}
}
const auto injector = di::make_injector( const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(), di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options), di::bind<ApplicationOptions>().to(options),
@@ -180,8 +208,8 @@ int main(const int argc, char** argv) {
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(), di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<std::string>().to(model_path), di::bind<std::string>().to(model_path),
di::bind<DataGenerator>().to( di::bind<DataGenerator>().to(
[options, model_path, [options, model_path, sampling, &prompt_directory](
sampling](const auto& inj) -> std::unique_ptr<DataGenerator> { const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) { if (options.generator.use_mocked) {
spdlog::info( spdlog::info(
"[Generator] Using MockGenerator (no model path provided)"); "[Generator] Using MockGenerator (no model path provided)");
@@ -193,7 +221,10 @@ int main(const int argc, char** argv) {
"top-p={}, top-k={}, n_ctx={}, seed={})", "top-p={}, top-k={}, n_ctx={}, seed={})",
model_path, sampling.temperature, sampling.top_p, model_path, sampling.temperature, sampling.top_p,
sampling.top_k, sampling.n_ctx, sampling.seed); sampling.top_k, sampling.n_ctx, sampling.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>(); return std::make_unique<LlamaGenerator>(
options, model_path,
inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory));
})); }));
auto generator = auto generator =

View File

@@ -0,0 +1,85 @@
/**
* @file services/prompt_directory.cc
* @brief PromptDirectory implementation: validates the directory at
* construction and loads named prompt files on demand with in-process caching.
*/
#include "services/prompt_directory.h"
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>
#include <string_view>
// ---------------------------------------------------------------------------
// PromptDirectory
// ---------------------------------------------------------------------------
/**
 * @brief Validates @p prompt_dir and stores it as the root for prompt lookups.
 *
 * Three checks run in order: the path exists, it is a directory, and a
 * directory iterator can be opened on it. When a check fails because the
 * filesystem call itself reported an error, that error's message is appended
 * to the exception text so the cause is not hidden behind a generic message.
 *
 * @param prompt_dir Absolute or relative path to the prompt directory.
 * @throws std::runtime_error if any of the checks fails.
 */
PromptDirectory::PromptDirectory(const std::filesystem::path& prompt_dir)
    : prompt_dir_(prompt_dir) {
  // Builds "PromptDirectory: <problem>: <path>[ (<reason>)]" and throws it.
  const auto fail = [this](const std::string& problem,
                           const std::error_code& reason) {
    std::string message =
        "PromptDirectory: " + problem + ": " + prompt_dir_.string();
    if (reason) {
      message += " (" + reason.message() + ")";
    }
    throw std::runtime_error(message);
  };

  std::error_code ec;

  // The path must exist. `ec` is only set when the query itself failed.
  if (!std::filesystem::exists(prompt_dir_, ec) || ec) {
    fail("prompt directory does not exist", ec);
  }

  // The path must refer to a directory, not a regular file.
  if (!std::filesystem::is_directory(prompt_dir_, ec) || ec) {
    fail("prompt directory path is not a directory", ec);
  }

  // The directory must be readable: probe by opening an iterator on it.
  std::filesystem::directory_iterator probe(prompt_dir_, ec);
  if (ec) {
    fail("prompt directory is not readable", ec);
  }

  spdlog::info("[PromptDirectory] Resolved prompt directory: {}",
               prompt_dir_.string());
}
std::string PromptDirectory::Load(std::string_view key) {
const std::string key_str(key);
// Return cached content if already loaded during this run.
const auto cache_it = cache_.find(key_str);
if (cache_it != cache_.end()) {
return cache_it->second;
}
// Scenario 3: resolve <prompt_dir>/<key>.md and require it to exist.
const std::filesystem::path file_path =
prompt_dir_ / std::filesystem::path(key_str + ".md");
std::ifstream file(file_path);
if (!file.is_open()) {
throw std::runtime_error(
"PromptDirectory: prompt file not found for key '" + key_str +
"': " + file_path.string());
}
std::string content((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
file.close();
if (content.empty()) {
throw std::runtime_error("PromptDirectory: prompt file for key '" +
key_str + "' is empty: " + file_path.string());
}
spdlog::info("[PromptDirectory] Loaded prompt '{}' from '{}' ({} chars)",
key_str, file_path.string(), content.size());
cache_.emplace(key_str, content);
return content;
}

View File

@@ -8,7 +8,6 @@
#include "services/sqlite_export_service.h" #include "services/sqlite_export_service.h"
#include "services/sqlite_export_service_helpers.h" #include "services/sqlite_export_service_helpers.h"
void SqliteExportService::Finalize() { void SqliteExportService::Finalize() {
if (db_handle_ == nullptr) { if (db_handle_ == nullptr) {
return; return;

View File

@@ -10,7 +10,8 @@ void SqliteDatabaseDeleter::operator()(sqlite3* handle) const noexcept {
} }
} }
void SqliteStatementDeleter::operator()(sqlite3_stmt* statement) const noexcept { void SqliteStatementDeleter::operator()(
sqlite3_stmt* statement) const noexcept {
if (statement != nullptr) { if (statement != nullptr) {
sqlite3_finalize(statement); sqlite3_finalize(statement);
} }
@@ -23,7 +24,6 @@ void ThrowSqliteError(sqlite3* db_handle, std::string_view action) {
} }
SqliteDatabaseHandle OpenDatabase(const std::filesystem::path& path) { SqliteDatabaseHandle OpenDatabase(const std::filesystem::path& path) {
sqlite3* raw_handle = nullptr; sqlite3* raw_handle = nullptr;
const int result = sqlite3_open(path.string().c_str(), &raw_handle); const int result = sqlite3_open(path.string().c_str(), &raw_handle);
@@ -54,7 +54,8 @@ void ExecSql(const SqliteDatabaseHandle& db_handle, std::string_view sql,
} }
} }
void RollbackTransactionNoThrow(const SqliteDatabaseHandle& db_handle) noexcept { void RollbackTransactionNoThrow(
const SqliteDatabaseHandle& db_handle) noexcept {
if (!db_handle) { if (!db_handle) {
return; return;
} }
@@ -63,4 +64,3 @@ void RollbackTransactionNoThrow(const SqliteDatabaseHandle& db_handle) noexcept
} }
} // namespace sqlite_export_service_internal } // namespace sqlite_export_service_internal

View File

@@ -1,11 +1,12 @@
#include "services/sqlite_statement_helpers.h" #include "services/sqlite_statement_helpers.h"
#include "services/sqlite_connection_helpers.h"
#include <cstring>
#include <memory>
#include <limits>
#include <stdexcept>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>
#include "services/sqlite_connection_helpers.h"
namespace sqlite_export_service_internal { namespace sqlite_export_service_internal {
@@ -86,16 +87,6 @@ sqlite3_int64 LastInsertRowId(const SqliteDatabaseHandle& db_handle) {
return sqlite3_last_insert_rowid(db_handle.get()); return sqlite3_last_insert_rowid(db_handle.get());
} }
std::string SerializeLocalLanguages(
const std::vector<std::string>& local_languages) {
boost::json::array array;
array.reserve(local_languages.size());
for (const auto& language : local_languages) {
array.emplace_back(language);
}
return boost::json::serialize(array);
}
std::string SerializeVector(const std::vector<std::string>& str_vec) { std::string SerializeVector(const std::vector<std::string>& str_vec) {
boost::json::array array(str_vec.size()); boost::json::array array(str_vec.size());
for (const auto& s : str_vec) { for (const auto& s : str_vec) {
@@ -105,4 +96,3 @@ std::string SerializeVector(const std::vector<std::string>& str_vec) {
} }
} // namespace sqlite_export_service_internal } // namespace sqlite_export_service_internal

View File

@@ -11,7 +11,6 @@
#include "services/sqlite_export_service.h" #include "services/sqlite_export_service.h"
#include "services/sqlite_export_service_helpers.h" #include "services/sqlite_export_service_helpers.h"
void SqliteExportService::InitializeSchema() const { void SqliteExportService::InitializeSchema() const {
sqlite_export_service_internal::ExecSql( sqlite_export_service_internal::ExecSql(
db_handle_, sqlite_export_service_internal::kCreateLocationsTableSql, db_handle_, sqlite_export_service_internal::kCreateLocationsTableSql,
@@ -46,7 +45,6 @@ void SqliteExportService::RollbackAndCloseNoThrow() noexcept {
location_cache_.clear(); location_cache_.clear();
} }
void SqliteExportService::Initialize() { void SqliteExportService::Initialize() {
if (db_handle_ != nullptr) { if (db_handle_ != nullptr) {
throw std::runtime_error("SQLite export service is already initialized"); throw std::runtime_error("SQLite export service is already initialized");

View File

@@ -3,6 +3,8 @@
* @brief SqliteExportService::ProcessRecord() implementation. * @brief SqliteExportService::ProcessRecord() implementation.
*/ */
#include <iomanip>
#include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>

View File

@@ -17,6 +17,7 @@ using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static constexpr long kConnectionTimeout = 10; static constexpr long kConnectionTimeout = 10;
static constexpr long kRequestTimeout = 30; static constexpr long kRequestTimeout = 30;
static constexpr long kMaxRedirects = 5;
static constexpr int32_t kOkHttpStatus = 200; static constexpr int32_t kOkHttpStatus = 200;
static CurlHandle CreateHandle() { static CurlHandle CreateHandle() {
@@ -32,7 +33,7 @@ static void SetCommonGetOptions(CURL* curl, const std::string& url) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); curl_easy_setopt(curl, CURLOPT_MAXREDIRS, kMaxRedirects);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, kConnectionTimeout); curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, kConnectionTimeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, kRequestTimeout); curl_easy_setopt(curl, CURLOPT_TIMEOUT, kRequestTimeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");