mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
fix: llama backend lifetime, Wikipedia enrichment depth, and misc cleanup
This commit is contained in:
@@ -3,8 +3,7 @@
|
||||
* @brief LlamaGenerator constructor implementation.
|
||||
*/
|
||||
|
||||
#include <llama.h>
|
||||
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
@@ -12,7 +11,8 @@
|
||||
#include "data_generation/llama_generator.h"
|
||||
|
||||
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
const std::string& model_path) {
|
||||
const std::string& model_path)
|
||||
: rng_() {
|
||||
if (model_path.empty()) {
|
||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||
}
|
||||
@@ -39,15 +39,13 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
|
||||
sampling_temperature_ = options.temperature;
|
||||
sampling_top_p_ = options.top_p;
|
||||
sampling_seed_ = (options.seed < 0)
|
||||
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
|
||||
: static_cast<uint32_t>(options.seed);
|
||||
if (options.seed == -1) {
|
||||
std::random_device random_device;
|
||||
rng_.seed(random_device());
|
||||
} else {
|
||||
rng_.seed(static_cast<uint32_t>(options.seed));
|
||||
}
|
||||
n_ctx_ = options.n_ctx;
|
||||
|
||||
try {
|
||||
Load(model_path);
|
||||
} catch (...) {
|
||||
llama_backend_free();
|
||||
throw;
|
||||
}
|
||||
Load(model_path);
|
||||
}
|
||||
|
||||
@@ -23,9 +23,4 @@ LlamaGenerator::~LlamaGenerator() {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up the backend (GPU/CPU acceleration resources)
|
||||
*/
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
@@ -145,8 +145,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
* Distribution sampler: selects actual token using configured seed for
|
||||
* reproducibility
|
||||
*/
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_dist(sampling_seed_));
|
||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
|
||||
|
||||
/**
|
||||
* TOKEN GENERATION LOOP
|
||||
@@ -187,10 +186,5 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
for (const llama_token token : generated_tokens)
|
||||
AppendTokenPiecePublic(vocab, token, output);
|
||||
|
||||
/**
|
||||
* Advance seed for next generation to improve output diversity
|
||||
*/
|
||||
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
@@ -22,11 +23,6 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
model_ = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the llama backend (one-time setup for GPU/CPU acceleration)
|
||||
*/
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
|
||||
if (model_ == nullptr) {
|
||||
@@ -36,7 +32,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
|
||||
llama_context_params context_params = llama_context_default_params();
|
||||
context_params.n_ctx = n_ctx_;
|
||||
context_params.n_batch = n_ctx_; // Set batch size equal to context window
|
||||
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(512));
|
||||
|
||||
context_ = llama_init_from_model(model_, context_params);
|
||||
if (context_ == nullptr) {
|
||||
|
||||
Reference in New Issue
Block a user