mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-05-31 17:53:59 +00:00
Use unique_ptr with custom deleter for llama
This commit is contained in:
@@ -7,6 +7,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
@@ -65,6 +66,17 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
||||||
static constexpr uint32_t kDefaultContextSize = 8192;
|
static constexpr uint32_t kDefaultContextSize = 8192;
|
||||||
|
|
||||||
|
struct ModelDeleter {
|
||||||
|
void operator()(llama_model* model) const noexcept;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ContextDeleter {
|
||||||
|
void operator()(llama_context* context) const noexcept;
|
||||||
|
};
|
||||||
|
|
||||||
|
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
||||||
|
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
||||||
|
|
||||||
struct SamplerState {
|
struct SamplerState {
|
||||||
SamplerState() = default;
|
SamplerState() = default;
|
||||||
~SamplerState();
|
~SamplerState();
|
||||||
@@ -116,8 +128,8 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
*/
|
*/
|
||||||
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
|
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
|
||||||
|
|
||||||
llama_model* model_ = nullptr;
|
ModelHandle model_;
|
||||||
llama_context* context_ = nullptr;
|
ContextHandle context_;
|
||||||
/// @brief Persistent sampler chain reused across inference calls.
|
/// @brief Persistent sampler chain reused across inference calls.
|
||||||
std::unique_ptr<SamplerState> sampler_;
|
std::unique_ptr<SamplerState> sampler_;
|
||||||
float sampling_temperature_ = 1.0F;
|
float sampling_temperature_ = 1.0F;
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ static constexpr std::size_t kPromptTokenSlack = 8;
|
|||||||
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
||||||
const std::string& prompt,
|
const std::string& prompt,
|
||||||
const int max_tokens) {
|
const int max_tokens) {
|
||||||
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
|
return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt),
|
||||||
max_tokens);
|
max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,14 +31,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
/**
|
/**
|
||||||
* Validate that model and context are loaded
|
* Validate that model and context are loaded
|
||||||
*/
|
*/
|
||||||
if (model_ == nullptr || context_ == nullptr) {
|
if (!model_ || !context_) {
|
||||||
throw std::runtime_error("LlamaGenerator: model not loaded");
|
throw std::runtime_error("LlamaGenerator: model not loaded");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get vocabulary for tokenization and token-to-text conversion
|
* Get vocabulary for tokenization and token-to-text conversion
|
||||||
*/
|
*/
|
||||||
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
const llama_vocab* vocab = llama_model_get_vocab(model_.get());
|
||||||
if (vocab == nullptr) {
|
if (vocab == nullptr) {
|
||||||
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
/**
|
/**
|
||||||
* Clear KV cache to ensure clean inference state (no residual context)
|
* Clear KV cache to ensure clean inference state (no residual context)
|
||||||
*/
|
*/
|
||||||
llama_memory_clear(llama_get_memory(context_), true);
|
llama_memory_clear(llama_get_memory(context_.get()), true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TOKENIZATION PHASE
|
* TOKENIZATION PHASE
|
||||||
@@ -79,8 +79,8 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
* Validate and compute effective token budgets based on context window
|
* Validate and compute effective token budgets based on context window
|
||||||
* constraints
|
* constraints
|
||||||
*/
|
*/
|
||||||
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
|
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_.get()));
|
||||||
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
|
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_.get()));
|
||||||
if (n_ctx <= 1 || n_batch <= 0) {
|
if (n_ctx <= 1 || n_batch <= 0) {
|
||||||
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
||||||
}
|
}
|
||||||
@@ -117,7 +117,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
const llama_batch prompt_batch = llama_batch_get_one(
|
const llama_batch prompt_batch = llama_batch_get_one(
|
||||||
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
|
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
|
||||||
if (llama_decode(context_, prompt_batch) != 0) {
|
if (llama_decode(context_.get(), prompt_batch) != 0) {
|
||||||
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +139,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
* Index -1 means use the last output position from previous batch
|
* Index -1 means use the last output position from previous batch
|
||||||
*/
|
*/
|
||||||
const llama_token next =
|
const llama_token next =
|
||||||
llama_sampler_sample(sampler_->chain, context_, -1);
|
llama_sampler_sample(sampler_->chain, context_.get(), -1);
|
||||||
/**
|
/**
|
||||||
* Stop if model predicts end-of-generation token (EOS/EOT)
|
* Stop if model predicts end-of-generation token (EOS/EOT)
|
||||||
*/
|
*/
|
||||||
@@ -153,7 +153,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
llama_token decode_token = next;
|
llama_token decode_token = next;
|
||||||
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
|
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
|
||||||
if (llama_decode(context_, one_token_batch) != 0) {
|
if (llama_decode(context_.get(), one_token_batch) != 0) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: decode failed during generation");
|
"LlamaGenerator: decode failed during generation");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,18 @@ struct SamplerConfig {
|
|||||||
using SamplerPtr =
|
using SamplerPtr =
|
||||||
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||||
|
|
||||||
|
void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept {
|
||||||
|
if (model != nullptr) {
|
||||||
|
llama_model_free(model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept {
|
||||||
|
if (context != nullptr) {
|
||||||
|
llama_free(context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
|
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
|
||||||
std::mt19937& rng) {
|
std::mt19937& rng) {
|
||||||
const llama_sampler_chain_params sampler_params =
|
const llama_sampler_chain_params sampler_params =
|
||||||
@@ -88,6 +100,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
|||||||
sampling_temperature_ = options.temperature;
|
sampling_temperature_ = options.temperature;
|
||||||
sampling_top_p_ = options.top_p;
|
sampling_top_p_ = options.top_p;
|
||||||
sampling_top_k_ = options.top_k;
|
sampling_top_k_ = options.top_k;
|
||||||
|
|
||||||
if (options.seed == -1) {
|
if (options.seed == -1) {
|
||||||
std::random_device random_device;
|
std::random_device random_device;
|
||||||
rng_.seed(random_device());
|
rng_.seed(random_device());
|
||||||
@@ -100,26 +113,8 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
|||||||
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
|
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
|
||||||
sampling_top_k_};
|
sampling_top_k_};
|
||||||
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
|
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
|
||||||
sampler_.reset(new SamplerState());
|
sampler_ = std::make_unique<SamplerState>();
|
||||||
sampler_->chain = sampler_chain.release();
|
sampler_->chain = sampler_chain.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
LlamaGenerator::~LlamaGenerator() {
|
LlamaGenerator::~LlamaGenerator() = default;
|
||||||
sampler_.reset();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Free the inference context (contains KV cache and computation state)
|
|
||||||
*/
|
|
||||||
if (context_ != nullptr) {
|
|
||||||
llama_free(context_);
|
|
||||||
context_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Free the loaded model (contains weights and vocabulary)
|
|
||||||
*/
|
|
||||||
if (model_ != nullptr) {
|
|
||||||
llama_model_free(model_);
|
|
||||||
model_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,23 +9,19 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
void LlamaGenerator::Load(const std::string& model_path) {
|
void LlamaGenerator::Load(const std::string& model_path) {
|
||||||
if (context_ != nullptr) {
|
context_.reset();
|
||||||
llama_free(context_);
|
model_.reset();
|
||||||
context_ = nullptr;
|
|
||||||
}
|
|
||||||
if (model_ != nullptr) {
|
|
||||||
llama_model_free(model_);
|
|
||||||
model_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const llama_model_params model_params = llama_model_default_params();
|
const llama_model_params model_params = llama_model_default_params();
|
||||||
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
|
LlamaGenerator::ModelHandle loaded_model(
|
||||||
if (model_ == nullptr) {
|
llama_model_load_from_file(model_path.c_str(), model_params));
|
||||||
|
if (!loaded_model) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: failed to load model from path: " + model_path);
|
"LlamaGenerator: failed to load model from path: " + model_path);
|
||||||
}
|
}
|
||||||
@@ -34,12 +30,14 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
|||||||
context_params.n_ctx = n_ctx_;
|
context_params.n_ctx = n_ctx_;
|
||||||
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
|
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
|
||||||
|
|
||||||
context_ = llama_init_from_model(model_, context_params);
|
LlamaGenerator::ContextHandle loaded_context(
|
||||||
if (context_ == nullptr) {
|
llama_init_from_model(loaded_model.get(), context_params));
|
||||||
llama_model_free(model_);
|
if (!loaded_context) {
|
||||||
model_ = nullptr;
|
|
||||||
throw std::runtime_error("LlamaGenerator: failed to create context");
|
throw std::runtime_error("LlamaGenerator: failed to create context");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model_ = std::move(loaded_model);
|
||||||
|
context_ = std::move(loaded_context);
|
||||||
|
|
||||||
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
|
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user