refactor(pipeline): restructure config, add PromptDirectory, consolidate SQLite layer (#217)

* Refactor ApplicationOptions to separate config concerns

* add prompt dir app option

* readability updates: remove magic numbers, update comments

* codebase formatting

* Update docs

* Extract argument parsing, timer out of
This commit is contained in:
2026-05-02 18:27:14 -04:00
committed by GitHub
parent 641a479b6a
commit b1dc8e0b5d
35 changed files with 561 additions and 310 deletions

View File

@@ -32,9 +32,11 @@ void LlamaGenerator::ContextDeleter::operator()(
LlamaGenerator::LlamaGenerator(
const ApplicationOptions& options, const std::string& model_path,
std::unique_ptr<IPromptFormatter> prompt_formatter)
std::unique_ptr<IPromptFormatter> prompt_formatter,
std::unique_ptr<IPromptDirectory> prompt_directory)
: rng_(std::random_device{}()),
prompt_formatter_(std::move(prompt_formatter)) {
prompt_formatter_(std::move(prompt_formatter)),
prompt_directory_(std::move(prompt_directory)) {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
@@ -44,41 +46,49 @@ LlamaGenerator::LlamaGenerator(
"LlamaGenerator: prompt formatter dependency must not be null");
}
if (options.temperature < 0.0F) {
if (!prompt_directory_) {
throw std::runtime_error(
"LlamaGenerator: prompt directory dependency must not be null");
}
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
if (sampling.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
if (sampling.top_p <= 0.0F || sampling.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.top_k == 0U) {
if (sampling.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
}
if (options.seed < -1) {
if (sampling.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) {
if (sampling.n_ctx == 0 || sampling.n_ctx > kMaxContextSize) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k;
sampling_temperature_ = sampling.temperature;
sampling_top_p_ = sampling.top_p;
sampling_top_k_ = sampling.top_k;
if (options.seed == -1) {
if (sampling.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
rng_.seed(static_cast<uint32_t>(sampling.seed));
}
n_ctx_ = options.n_ctx;
n_ctx_ = sampling.n_ctx;
this->Load(model_path);
}