mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Use unique_ptr with custom deleter for llama
This commit is contained in:
@@ -22,7 +22,7 @@ static constexpr std::size_t kPromptTokenSlack = 8;
|
||||
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
||||
const std::string& prompt,
|
||||
const int max_tokens) {
|
||||
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
|
||||
return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt),
|
||||
max_tokens);
|
||||
}
|
||||
|
||||
@@ -31,14 +31,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
/**
|
||||
* Validate that model and context are loaded
|
||||
*/
|
||||
if (model_ == nullptr || context_ == nullptr) {
|
||||
if (!model_ || !context_) {
|
||||
throw std::runtime_error("LlamaGenerator: model not loaded");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get vocabulary for tokenization and token-to-text conversion
|
||||
*/
|
||||
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
||||
const llama_vocab* vocab = llama_model_get_vocab(model_.get());
|
||||
if (vocab == nullptr) {
|
||||
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
||||
}
|
||||
@@ -46,7 +46,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
/**
|
||||
* Clear KV cache to ensure clean inference state (no residual context)
|
||||
*/
|
||||
llama_memory_clear(llama_get_memory(context_), true);
|
||||
llama_memory_clear(llama_get_memory(context_.get()), true);
|
||||
|
||||
/**
|
||||
* TOKENIZATION PHASE
|
||||
@@ -79,8 +79,8 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
* Validate and compute effective token budgets based on context window
|
||||
* constraints
|
||||
*/
|
||||
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
|
||||
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
|
||||
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_.get()));
|
||||
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_.get()));
|
||||
if (n_ctx <= 1 || n_batch <= 0) {
|
||||
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
||||
}
|
||||
@@ -117,7 +117,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
*/
|
||||
const llama_batch prompt_batch = llama_batch_get_one(
|
||||
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
|
||||
if (llama_decode(context_, prompt_batch) != 0) {
|
||||
if (llama_decode(context_.get(), prompt_batch) != 0) {
|
||||
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
* Index -1 means use the last output position from previous batch
|
||||
*/
|
||||
const llama_token next =
|
||||
llama_sampler_sample(sampler_->chain, context_, -1);
|
||||
llama_sampler_sample(sampler_->chain, context_.get(), -1);
|
||||
/**
|
||||
* Stop if model predicts end-of-generation token (EOS/EOT)
|
||||
*/
|
||||
@@ -153,7 +153,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
*/
|
||||
llama_token decode_token = next;
|
||||
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
|
||||
if (llama_decode(context_, one_token_batch) != 0) {
|
||||
if (llama_decode(context_.get(), one_token_batch) != 0) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: decode failed during generation");
|
||||
}
|
||||
|
||||
@@ -24,6 +24,18 @@ struct SamplerConfig {
|
||||
using SamplerPtr =
|
||||
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||
|
||||
void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept {
|
||||
if (model != nullptr) {
|
||||
llama_model_free(model);
|
||||
}
|
||||
}
|
||||
|
||||
void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept {
|
||||
if (context != nullptr) {
|
||||
llama_free(context);
|
||||
}
|
||||
}
|
||||
|
||||
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
|
||||
std::mt19937& rng) {
|
||||
const llama_sampler_chain_params sampler_params =
|
||||
@@ -88,6 +100,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
sampling_temperature_ = options.temperature;
|
||||
sampling_top_p_ = options.top_p;
|
||||
sampling_top_k_ = options.top_k;
|
||||
|
||||
if (options.seed == -1) {
|
||||
std::random_device random_device;
|
||||
rng_.seed(random_device());
|
||||
@@ -100,26 +113,8 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
|
||||
sampling_top_k_};
|
||||
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
|
||||
sampler_.reset(new SamplerState());
|
||||
sampler_ = std::make_unique<SamplerState>();
|
||||
sampler_->chain = sampler_chain.release();
|
||||
}
|
||||
|
||||
LlamaGenerator::~LlamaGenerator() {
|
||||
sampler_.reset();
|
||||
|
||||
/**
|
||||
* Free the inference context (contains KV cache and computation state)
|
||||
*/
|
||||
if (context_ != nullptr) {
|
||||
llama_free(context_);
|
||||
context_ = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free the loaded model (contains weights and vocabulary)
|
||||
*/
|
||||
if (model_ != nullptr) {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
}
|
||||
}
|
||||
LlamaGenerator::~LlamaGenerator() = default;
|
||||
|
||||
@@ -9,23 +9,19 @@
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "data_generation/llama_generator.h"
|
||||
#include "llama.h"
|
||||
|
||||
void LlamaGenerator::Load(const std::string& model_path) {
|
||||
if (context_ != nullptr) {
|
||||
llama_free(context_);
|
||||
context_ = nullptr;
|
||||
}
|
||||
if (model_ != nullptr) {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
}
|
||||
context_.reset();
|
||||
model_.reset();
|
||||
|
||||
const llama_model_params model_params = llama_model_default_params();
|
||||
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
|
||||
if (model_ == nullptr) {
|
||||
LlamaGenerator::ModelHandle loaded_model(
|
||||
llama_model_load_from_file(model_path.c_str(), model_params));
|
||||
if (!loaded_model) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: failed to load model from path: " + model_path);
|
||||
}
|
||||
@@ -34,12 +30,14 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
context_params.n_ctx = n_ctx_;
|
||||
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
|
||||
|
||||
context_ = llama_init_from_model(model_, context_params);
|
||||
if (context_ == nullptr) {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
LlamaGenerator::ContextHandle loaded_context(
|
||||
llama_init_from_model(loaded_model.get(), context_params));
|
||||
if (!loaded_context) {
|
||||
throw std::runtime_error("LlamaGenerator: failed to create context");
|
||||
}
|
||||
|
||||
model_ = std::move(loaded_model);
|
||||
context_ = std::move(loaded_context);
|
||||
|
||||
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user