diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index 07661d7..e7e9901 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -7,6 +7,7 @@ */ #include +#include #include #include #include @@ -65,6 +66,17 @@ class LlamaGenerator final : public DataGenerator { static constexpr uint32_t kDefaultSamplingTopK = 64; static constexpr uint32_t kDefaultContextSize = 8192; + struct ModelDeleter { + void operator()(llama_model* model) const noexcept; + }; + + struct ContextDeleter { + void operator()(llama_context* context) const noexcept; + }; + + using ModelHandle = std::unique_ptr; + using ContextHandle = std::unique_ptr; + struct SamplerState { SamplerState() = default; ~SamplerState(); @@ -116,8 +128,8 @@ class LlamaGenerator final : public DataGenerator { */ std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path); - llama_model* model_ = nullptr; - llama_context* context_ = nullptr; + ModelHandle model_; + ContextHandle context_; /// @brief Persistent sampler chain reused across inference calls. std::unique_ptr sampler_; float sampling_temperature_ = 1.0F; diff --git a/pipeline/src/data_generation/llama/infer.cc b/pipeline/src/data_generation/llama/infer.cc index c10fb8a..77e4787 100644 --- a/pipeline/src/data_generation/llama/infer.cc +++ b/pipeline/src/data_generation/llama/infer.cc @@ -22,7 +22,7 @@ static constexpr std::size_t kPromptTokenSlack = 8; std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, const int max_tokens) { - return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt), + return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt), max_tokens); } @@ -31,14 +31,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, /** * Validate that model and context are loaded */ - if (model_ == nullptr || context_ == nullptr) { + if (!model_ || !context_) { throw std::runtime_error("LlamaGenerator: model not loaded"); } /** * Get vocabulary for tokenization and token-to-text conversion */ - const llama_vocab* vocab = llama_model_get_vocab(model_); + const llama_vocab* vocab = llama_model_get_vocab(model_.get()); if (vocab == nullptr) { throw std::runtime_error("LlamaGenerator: vocab unavailable"); } @@ -46,7 +46,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, /** * Clear KV cache to ensure clean inference state (no residual context) */ - llama_memory_clear(llama_get_memory(context_), true); + llama_memory_clear(llama_get_memory(context_.get()), true); /** * TOKENIZATION PHASE @@ -79,8 +79,8 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, * Validate and compute effective token budgets based on context window * constraints */ - const auto n_ctx = static_cast(llama_n_ctx(context_)); - const auto n_batch = static_cast(llama_n_batch(context_)); + const auto n_ctx = static_cast(llama_n_ctx(context_.get())); + const auto n_batch = static_cast(llama_n_batch(context_.get())); if (n_ctx <= 1 || n_batch <= 0) { throw std::runtime_error("LlamaGenerator: invalid context or batch size"); } @@ -117,7 +117,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, */ const llama_batch prompt_batch = llama_batch_get_one( prompt_tokens.data(), static_cast(prompt_tokens.size())); - if (llama_decode(context_, prompt_batch) != 0) { + if (llama_decode(context_.get(), prompt_batch) != 0) { throw std::runtime_error("LlamaGenerator: prompt decode failed"); } @@ -139,7 +139,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, * Index -1 means use the last output position from previous batch */ const llama_token next = - llama_sampler_sample(sampler_->chain, context_, -1); + llama_sampler_sample(sampler_->chain, context_.get(), -1); /** * Stop if model predicts end-of-generation token (EOS/EOT) */ @@ -153,7 +153,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, */ llama_token decode_token = next; const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1); - if (llama_decode(context_, one_token_batch) != 0) { + if (llama_decode(context_.get(), one_token_batch) != 0) { throw std::runtime_error( "LlamaGenerator: decode failed during generation"); } diff --git a/pipeline/src/data_generation/llama/llama_generator.cc b/pipeline/src/data_generation/llama/llama_generator.cc index 180f591..7571b4d 100644 --- a/pipeline/src/data_generation/llama/llama_generator.cc +++ b/pipeline/src/data_generation/llama/llama_generator.cc @@ -24,6 +24,18 @@ struct SamplerConfig { using SamplerPtr = std::unique_ptr; +void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept { + if (model != nullptr) { + llama_model_free(model); + } +} + +void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept { + if (context != nullptr) { + llama_free(context); + } +} + static SamplerPtr CreateSamplerChain(const SamplerConfig& config, std::mt19937& rng) { const llama_sampler_chain_params sampler_params = @@ -88,6 +100,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, sampling_temperature_ = options.temperature; sampling_top_p_ = options.top_p; sampling_top_k_ = options.top_k; + if (options.seed == -1) { std::random_device random_device; rng_.seed(random_device()); @@ -100,26 +113,8 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_, sampling_top_k_}; auto sampler_chain = CreateSamplerChain(sampler_config, rng_); - sampler_.reset(new SamplerState()); + sampler_ = std::make_unique(); sampler_->chain = sampler_chain.release(); } -LlamaGenerator::~LlamaGenerator() { - sampler_.reset(); - - /** - * Free the inference context (contains KV cache and computation state) - */ - if (context_ != nullptr) { - llama_free(context_); - context_ = nullptr; - } - - /** - * Free the loaded model (contains weights and vocabulary) - */ - if (model_ != nullptr) { - llama_model_free(model_); - model_ = nullptr; - } -} +LlamaGenerator::~LlamaGenerator() = default; diff --git a/pipeline/src/data_generation/llama/load.cc b/pipeline/src/data_generation/llama/load.cc index cb7357e..98feb5a 100644 --- a/pipeline/src/data_generation/llama/load.cc +++ b/pipeline/src/data_generation/llama/load.cc @@ -9,23 +9,19 @@ #include #include #include +#include #include "data_generation/llama_generator.h" #include "llama.h" void LlamaGenerator::Load(const std::string& model_path) { - if (context_ != nullptr) { - llama_free(context_); - context_ = nullptr; - } - if (model_ != nullptr) { - llama_model_free(model_); - model_ = nullptr; - } + context_.reset(); + model_.reset(); const llama_model_params model_params = llama_model_default_params(); - model_ = llama_model_load_from_file(model_path.c_str(), model_params); - if (model_ == nullptr) { + LlamaGenerator::ModelHandle loaded_model( + llama_model_load_from_file(model_path.c_str(), model_params)); + if (!loaded_model) { throw std::runtime_error( "LlamaGenerator: failed to load model from path: " + model_path); } @@ -34,12 +30,14 @@ void LlamaGenerator::Load(const std::string& model_path) { context_params.n_ctx = n_ctx_; context_params.n_batch = std::min(n_ctx_, static_cast(5000)); - context_ = llama_init_from_model(model_, context_params); - if (context_ == nullptr) { - llama_model_free(model_); - model_ = nullptr; + LlamaGenerator::ContextHandle loaded_context( + llama_init_from_model(loaded_model.get(), context_params)); + if (!loaded_context) { throw std::runtime_error("LlamaGenerator: failed to create context"); } + model_ = std::move(loaded_model); + context_ = std::move(loaded_context); + spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); }