Refactor BiergartenDataGenerator to use dependency injection container

This commit is contained in:
Aaron Po
2026-04-09 20:33:48 -04:00
parent 5d93d76e99
commit 824f5b2b4f
23 changed files with 332 additions and 394 deletions

View File

@@ -0,0 +1,53 @@
/**
* @file data_generation/llama/constructor.cpp
* @brief LlamaGenerator constructor implementation.
*/
#include <llama.h>
#include <stdexcept>
#include <string>
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
if (options.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_seed_ = (options.seed < 0)
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(options.seed);
n_ctx_ = options.n_ctx;
try {
Load(model_path);
} catch (...) {
llama_backend_free();
throw;
}
}

View File

@@ -1,7 +1,7 @@
/**
* @file data_generation/llama/load.cpp
* @brief Initializes llama backend, loads model weights, creates inference
* context, and resets prior resources during model reload.
* context, and resets prior resources during model initialization.
*/
#include <spdlog/spdlog.h>
@@ -13,12 +13,6 @@
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) {
/**
* Validate input and clean up any previously loaded model/context
*/
if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;

View File

@@ -1,64 +0,0 @@
/**
* @file data_generation/llama/set_sampling_options.cpp
* @brief Validates and stores sampling temperature, top-p, seed, and context
* size configuration used by subsequent LlamaGenerator inference calls.
*/
#include <stdexcept>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) {
/**
* Validate temperature: controls randomness in output distribution
* 0.0 = deterministic (always pick highest probability token)
* Higher values = more random/diverse output
*/
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
/**
* Validate top-p (nucleus sampling): only sample from top cumulative
* probability e.g., top-p=0.9 means sample from tokens that make up 90% of
* probability mass
*/
if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
/**
* Validate seed: for reproducible results (-1 uses random seed)
*/
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
/**
* Store sampling parameters for use during token generation
*/
sampling_temperature_ = temperature;
sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}
void LlamaGenerator::SetContextSize(uint32_t n_ctx) {
/**
* Validate context size: must be positive and reasonable for the model
*/
if (n_ctx == 0 || n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
/**
* Store context size for use during model loading
*/
n_ctx_ = n_ctx;
}