mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Refactor BiergartenDataGenerator to use dependency injection container
This commit is contained in:
53
pipeline/src/data_generation/llama/constructor.cpp
Normal file
53
pipeline/src/data_generation/llama/constructor.cpp
Normal file
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* @file data_generation/llama/constructor.cpp
|
||||
* @brief LlamaGenerator constructor implementation.
|
||||
*/
|
||||
|
||||
#include <llama.h>
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "biergarten_data_generator.h"
|
||||
#include "data_generation/llama_generator.h"
|
||||
|
||||
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
const std::string& model_path) {
|
||||
if (model_path.empty()) {
|
||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||
}
|
||||
|
||||
if (options.temperature < 0.0F) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: sampling temperature must be >= 0");
|
||||
}
|
||||
|
||||
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: sampling top-p must be in (0, 1]");
|
||||
}
|
||||
|
||||
if (options.seed < -1) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: seed must be >= 0, or -1 for random");
|
||||
}
|
||||
|
||||
if (options.n_ctx == 0 || options.n_ctx > 32768) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: context size must be in range [1, 32768]");
|
||||
}
|
||||
|
||||
sampling_temperature_ = options.temperature;
|
||||
sampling_top_p_ = options.top_p;
|
||||
sampling_seed_ = (options.seed < 0)
|
||||
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
|
||||
: static_cast<uint32_t>(options.seed);
|
||||
n_ctx_ = options.n_ctx;
|
||||
|
||||
try {
|
||||
Load(model_path);
|
||||
} catch (...) {
|
||||
llama_backend_free();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
/**
|
||||
* @file data_generation/llama/load.cpp
|
||||
* @brief Initializes llama backend, loads model weights, creates inference
|
||||
* context, and resets prior resources during model reload.
|
||||
* context, and resets prior resources during model initialization.
|
||||
*/
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
@@ -13,12 +13,6 @@
|
||||
#include "llama.h"
|
||||
|
||||
void LlamaGenerator::Load(const std::string& model_path) {
|
||||
/**
|
||||
* Validate input and clean up any previously loaded model/context
|
||||
*/
|
||||
if (model_path.empty())
|
||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||
|
||||
if (context_ != nullptr) {
|
||||
llama_free(context_);
|
||||
context_ = nullptr;
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
/**
|
||||
* @file data_generation/llama/set_sampling_options.cpp
|
||||
* @brief Validates and stores sampling temperature, top-p, seed, and context
|
||||
* size configuration used by subsequent LlamaGenerator inference calls.
|
||||
*/
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "data_generation/llama_generator.h"
|
||||
#include "llama.h"
|
||||
|
||||
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
|
||||
int seed) {
|
||||
/**
|
||||
* Validate temperature: controls randomness in output distribution
|
||||
* 0.0 = deterministic (always pick highest probability token)
|
||||
* Higher values = more random/diverse output
|
||||
*/
|
||||
if (temperature < 0.0f) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: sampling temperature must be >= 0");
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate top-p (nucleus sampling): only sample from top cumulative
|
||||
* probability e.g., top-p=0.9 means sample from tokens that make up 90% of
|
||||
* probability mass
|
||||
*/
|
||||
if (!(top_p > 0.0f && top_p <= 1.0f)) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: sampling top-p must be in (0, 1]");
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate seed: for reproducible results (-1 uses random seed)
|
||||
*/
|
||||
if (seed < -1) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: seed must be >= 0, or -1 for random");
|
||||
}
|
||||
|
||||
/**
|
||||
* Store sampling parameters for use during token generation
|
||||
*/
|
||||
sampling_temperature_ = temperature;
|
||||
sampling_top_p_ = top_p;
|
||||
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
|
||||
: static_cast<uint32_t>(seed);
|
||||
}
|
||||
|
||||
void LlamaGenerator::SetContextSize(uint32_t n_ctx) {
|
||||
/**
|
||||
* Validate context size: must be positive and reasonable for the model
|
||||
*/
|
||||
if (n_ctx == 0 || n_ctx > 32768) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: context size must be in range [1, 32768]");
|
||||
}
|
||||
|
||||
/**
|
||||
* Store context size for use during model loading
|
||||
*/
|
||||
n_ctx_ = n_ctx;
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
/**
|
||||
* @file data_generation/mock/load.cpp
|
||||
* @brief Provides MockGenerator initialization behavior, which is a no-op load
|
||||
* path that logs readiness without model resources.
|
||||
*/
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "data_generation/mock_generator.h"
|
||||
|
||||
void MockGenerator::Load(const std::string& /*modelPath*/) {
|
||||
spdlog::info("[MockGenerator] No model needed");
|
||||
}
|
||||
Reference in New Issue
Block a user