mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Use unique_ptr with custom deleter for llama
This commit is contained in:
@@ -9,23 +9,19 @@
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "data_generation/llama_generator.h"
|
||||
#include "llama.h"
|
||||
|
||||
void LlamaGenerator::Load(const std::string& model_path) {
|
||||
if (context_ != nullptr) {
|
||||
llama_free(context_);
|
||||
context_ = nullptr;
|
||||
}
|
||||
if (model_ != nullptr) {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
}
|
||||
context_.reset();
|
||||
model_.reset();
|
||||
|
||||
const llama_model_params model_params = llama_model_default_params();
|
||||
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
|
||||
if (model_ == nullptr) {
|
||||
LlamaGenerator::ModelHandle loaded_model(
|
||||
llama_model_load_from_file(model_path.c_str(), model_params));
|
||||
if (!loaded_model) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: failed to load model from path: " + model_path);
|
||||
}
|
||||
@@ -34,12 +30,14 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
context_params.n_ctx = n_ctx_;
|
||||
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
|
||||
|
||||
context_ = llama_init_from_model(model_, context_params);
|
||||
if (context_ == nullptr) {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
LlamaGenerator::ContextHandle loaded_context(
|
||||
llama_init_from_model(loaded_model.get(), context_params));
|
||||
if (!loaded_context) {
|
||||
throw std::runtime_error("LlamaGenerator: failed to create context");
|
||||
}
|
||||
|
||||
model_ = std::move(loaded_model);
|
||||
context_ = std::move(loaded_context);
|
||||
|
||||
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user