fix: llama backend lifetime, Wikipedia enrichment depth, and misc cleanup

This commit is contained in:
Aaron Po
2026-04-09 21:59:13 -04:00
parent 824f5b2b4f
commit b53f9e5582
17 changed files with 161 additions and 104 deletions

View File

@@ -3,8 +3,7 @@
* @brief LlamaGenerator constructor implementation.
*/
#include <llama.h>
#include <random>
#include <stdexcept>
#include <string>
@@ -12,7 +11,8 @@
#include "data_generation/llama_generator.h"
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) {
const std::string& model_path)
: rng_() {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
@@ -39,15 +39,13 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_seed_ = (options.seed < 0)
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(options.seed);
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
try {
Load(model_path);
} catch (...) {
llama_backend_free();
throw;
}
Load(model_path);
}

View File

@@ -23,9 +23,4 @@ LlamaGenerator::~LlamaGenerator() {
llama_model_free(model_);
model_ = nullptr;
}
/**
* Clean up the backend (GPU/CPU acceleration resources)
*/
llama_backend_free();
}

View File

@@ -145,8 +145,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
/**
* TOKEN GENERATION LOOP
@@ -187,10 +186,5 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
/**
* Advance seed for next generation to improve output diversity
*/
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
return output;
}

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
@@ -22,11 +23,6 @@ void LlamaGenerator::Load(const std::string& model_path) {
model_ = nullptr;
}
/**
* Initialize the llama backend (one-time setup for GPU/CPU acceleration)
*/
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
@@ -36,7 +32,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_;
context_params.n_batch = n_ctx_; // Set batch size equal to context window
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(512));
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {