2 Commits

Author SHA1 Message Date
Aaron Po
15853c62fd remove const to enable use of std::move 2026-04-13 22:02:31 -04:00
Aaron Po
ff4b7f2578 Use unique_ptr with custom deleter for llama 2026-04-13 21:45:00 -04:00
5 changed files with 57 additions and 50 deletions

View File

@@ -7,6 +7,7 @@
*/
#include <cstdint>
#include <memory>
#include <random>
#include <string>
#include <string_view>
@@ -65,6 +66,17 @@ class LlamaGenerator final : public DataGenerator {
static constexpr uint32_t kDefaultSamplingTopK = 64;
static constexpr uint32_t kDefaultContextSize = 8192;
struct ModelDeleter {
void operator()(llama_model* model) const noexcept;
};
struct ContextDeleter {
void operator()(llama_context* context) const noexcept;
};
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
struct SamplerState {
SamplerState() = default;
~SamplerState();
@@ -116,8 +128,8 @@ class LlamaGenerator final : public DataGenerator {
*/
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
llama_model* model_ = nullptr;
llama_context* context_ = nullptr;
ModelHandle model_;
ContextHandle context_;
/// @brief Persistent sampler chain reused across inference calls.
std::unique_ptr<SamplerState> sampler_;
float sampling_temperature_ = 1.0F;

View File

@@ -3,26 +3,28 @@
* @brief BiergartenDataGenerator::Run() implementation.
*/
#include <utility>
#include <spdlog/spdlog.h>
#include "biergarten_data_generator.h"
bool BiergartenDataGenerator::Run() {
try {
const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
size_t skipped_count = 0;
for (const auto& city : cities) {
for (auto& city : cities) {
try {
const std::string region_context =
context_service_->GetLocationContext(city);
std::string region_context = context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
enriched.push_back(
EnrichedCity{.location = city, .region_context = region_context});
EnrichedCity{.location = std::move(city),
.region_context = std::move(region_context)});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(

View File

@@ -22,7 +22,7 @@ static constexpr std::size_t kPromptTokenSlack = 8;
std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt,
const int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt),
max_tokens);
}
@@ -31,14 +31,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
/**
* Validate that model and context are loaded
*/
if (model_ == nullptr || context_ == nullptr) {
if (!model_ || !context_) {
throw std::runtime_error("LlamaGenerator: model not loaded");
}
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_);
const llama_vocab* vocab = llama_model_get_vocab(model_.get());
if (vocab == nullptr) {
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
@@ -46,7 +46,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
/**
* Clear KV cache to ensure clean inference state (no residual context)
*/
llama_memory_clear(llama_get_memory(context_), true);
llama_memory_clear(llama_get_memory(context_.get()), true);
/**
* TOKENIZATION PHASE
@@ -79,8 +79,8 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* Validate and compute effective token budgets based on context window
* constraints
*/
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_.get()));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_.get()));
if (n_ctx <= 1 || n_batch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
@@ -117,7 +117,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) {
if (llama_decode(context_.get(), prompt_batch) != 0) {
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
@@ -139,7 +139,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler_->chain, context_, -1);
llama_sampler_sample(sampler_->chain, context_.get(), -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
@@ -153,7 +153,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
*/
llama_token decode_token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
if (llama_decode(context_, one_token_batch) != 0) {
if (llama_decode(context_.get(), one_token_batch) != 0) {
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}

View File

@@ -24,6 +24,18 @@ struct SamplerConfig {
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept {
if (model != nullptr) {
llama_model_free(model);
}
}
void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept {
if (context != nullptr) {
llama_free(context);
}
}
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
std::mt19937& rng) {
const llama_sampler_chain_params sampler_params =
@@ -88,6 +100,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k;
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
@@ -100,26 +113,8 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
sampling_top_k_};
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
sampler_.reset(new SamplerState());
sampler_ = std::make_unique<SamplerState>();
sampler_->chain = sampler_chain.release();
}
LlamaGenerator::~LlamaGenerator() {
sampler_.reset();
/**
* Free the inference context (contains KV cache and computation state)
*/
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
/**
* Free the loaded model (contains weights and vocabulary)
*/
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
}
LlamaGenerator::~LlamaGenerator() = default;

View File

@@ -9,23 +9,19 @@
#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) {
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
context_.reset();
model_.reset();
const llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
LlamaGenerator::ModelHandle loaded_model(
llama_model_load_from_file(model_path.c_str(), model_params));
if (!loaded_model) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
@@ -34,12 +30,14 @@ void LlamaGenerator::Load(const std::string& model_path) {
context_params.n_ctx = n_ctx_;
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
LlamaGenerator::ContextHandle loaded_context(
llama_init_from_model(loaded_model.get(), context_params));
if (!loaded_context) {
throw std::runtime_error("LlamaGenerator: failed to create context");
}
model_ = std::move(loaded_model);
context_ = std::move(loaded_context);
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}