mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 10:04:00 +00:00
Compare commits
2 Commits
3c70c46957
...
15853c62fd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
15853c62fd | ||
|
|
ff4b7f2578 |
@@ -7,6 +7,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
@@ -65,6 +66,17 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
||||||
static constexpr uint32_t kDefaultContextSize = 8192;
|
static constexpr uint32_t kDefaultContextSize = 8192;
|
||||||
|
|
||||||
|
struct ModelDeleter {
|
||||||
|
void operator()(llama_model* model) const noexcept;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ContextDeleter {
|
||||||
|
void operator()(llama_context* context) const noexcept;
|
||||||
|
};
|
||||||
|
|
||||||
|
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
||||||
|
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
||||||
|
|
||||||
struct SamplerState {
|
struct SamplerState {
|
||||||
SamplerState() = default;
|
SamplerState() = default;
|
||||||
~SamplerState();
|
~SamplerState();
|
||||||
@@ -116,8 +128,8 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
*/
|
*/
|
||||||
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
|
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
|
||||||
|
|
||||||
llama_model* model_ = nullptr;
|
ModelHandle model_;
|
||||||
llama_context* context_ = nullptr;
|
ContextHandle context_;
|
||||||
/// @brief Persistent sampler chain reused across inference calls.
|
/// @brief Persistent sampler chain reused across inference calls.
|
||||||
std::unique_ptr<SamplerState> sampler_;
|
std::unique_ptr<SamplerState> sampler_;
|
||||||
float sampling_temperature_ = 1.0F;
|
float sampling_temperature_ = 1.0F;
|
||||||
|
|||||||
@@ -3,26 +3,28 @@
|
|||||||
* @brief BiergartenDataGenerator::Run() implementation.
|
* @brief BiergartenDataGenerator::Run() implementation.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
#include "biergarten_data_generator.h"
|
#include "biergarten_data_generator.h"
|
||||||
|
|
||||||
bool BiergartenDataGenerator::Run() {
|
bool BiergartenDataGenerator::Run() {
|
||||||
try {
|
try {
|
||||||
const std::vector<Location> cities = QueryCitiesWithCountries();
|
std::vector<Location> cities = QueryCitiesWithCountries();
|
||||||
std::vector<EnrichedCity> enriched;
|
std::vector<EnrichedCity> enriched;
|
||||||
enriched.reserve(cities.size());
|
enriched.reserve(cities.size());
|
||||||
|
|
||||||
size_t skipped_count = 0;
|
size_t skipped_count = 0;
|
||||||
for (const auto& city : cities) {
|
for (auto& city : cities) {
|
||||||
try {
|
try {
|
||||||
const std::string region_context =
|
std::string region_context = context_service_->GetLocationContext(city);
|
||||||
context_service_->GetLocationContext(city);
|
|
||||||
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
|
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
|
||||||
city.city, city.country, region_context);
|
city.city, city.country, region_context);
|
||||||
|
|
||||||
enriched.push_back(
|
enriched.push_back(
|
||||||
EnrichedCity{.location = city, .region_context = region_context});
|
EnrichedCity{.location = std::move(city),
|
||||||
|
.region_context = std::move(region_context)});
|
||||||
} catch (const std::exception& exception) {
|
} catch (const std::exception& exception) {
|
||||||
++skipped_count;
|
++skipped_count;
|
||||||
spdlog::warn(
|
spdlog::warn(
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ static constexpr std::size_t kPromptTokenSlack = 8;
|
|||||||
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
||||||
const std::string& prompt,
|
const std::string& prompt,
|
||||||
const int max_tokens) {
|
const int max_tokens) {
|
||||||
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
|
return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt),
|
||||||
max_tokens);
|
max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,14 +31,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
/**
|
/**
|
||||||
* Validate that model and context are loaded
|
* Validate that model and context are loaded
|
||||||
*/
|
*/
|
||||||
if (model_ == nullptr || context_ == nullptr) {
|
if (!model_ || !context_) {
|
||||||
throw std::runtime_error("LlamaGenerator: model not loaded");
|
throw std::runtime_error("LlamaGenerator: model not loaded");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get vocabulary for tokenization and token-to-text conversion
|
* Get vocabulary for tokenization and token-to-text conversion
|
||||||
*/
|
*/
|
||||||
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
const llama_vocab* vocab = llama_model_get_vocab(model_.get());
|
||||||
if (vocab == nullptr) {
|
if (vocab == nullptr) {
|
||||||
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
/**
|
/**
|
||||||
* Clear KV cache to ensure clean inference state (no residual context)
|
* Clear KV cache to ensure clean inference state (no residual context)
|
||||||
*/
|
*/
|
||||||
llama_memory_clear(llama_get_memory(context_), true);
|
llama_memory_clear(llama_get_memory(context_.get()), true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TOKENIZATION PHASE
|
* TOKENIZATION PHASE
|
||||||
@@ -79,8 +79,8 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
* Validate and compute effective token budgets based on context window
|
* Validate and compute effective token budgets based on context window
|
||||||
* constraints
|
* constraints
|
||||||
*/
|
*/
|
||||||
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
|
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_.get()));
|
||||||
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
|
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_.get()));
|
||||||
if (n_ctx <= 1 || n_batch <= 0) {
|
if (n_ctx <= 1 || n_batch <= 0) {
|
||||||
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
||||||
}
|
}
|
||||||
@@ -117,7 +117,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
const llama_batch prompt_batch = llama_batch_get_one(
|
const llama_batch prompt_batch = llama_batch_get_one(
|
||||||
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
|
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
|
||||||
if (llama_decode(context_, prompt_batch) != 0) {
|
if (llama_decode(context_.get(), prompt_batch) != 0) {
|
||||||
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +139,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
* Index -1 means use the last output position from previous batch
|
* Index -1 means use the last output position from previous batch
|
||||||
*/
|
*/
|
||||||
const llama_token next =
|
const llama_token next =
|
||||||
llama_sampler_sample(sampler_->chain, context_, -1);
|
llama_sampler_sample(sampler_->chain, context_.get(), -1);
|
||||||
/**
|
/**
|
||||||
* Stop if model predicts end-of-generation token (EOS/EOT)
|
* Stop if model predicts end-of-generation token (EOS/EOT)
|
||||||
*/
|
*/
|
||||||
@@ -153,7 +153,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
llama_token decode_token = next;
|
llama_token decode_token = next;
|
||||||
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
|
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
|
||||||
if (llama_decode(context_, one_token_batch) != 0) {
|
if (llama_decode(context_.get(), one_token_batch) != 0) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: decode failed during generation");
|
"LlamaGenerator: decode failed during generation");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,18 @@ struct SamplerConfig {
|
|||||||
using SamplerPtr =
|
using SamplerPtr =
|
||||||
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||||
|
|
||||||
|
void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept {
|
||||||
|
if (model != nullptr) {
|
||||||
|
llama_model_free(model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept {
|
||||||
|
if (context != nullptr) {
|
||||||
|
llama_free(context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
|
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
|
||||||
std::mt19937& rng) {
|
std::mt19937& rng) {
|
||||||
const llama_sampler_chain_params sampler_params =
|
const llama_sampler_chain_params sampler_params =
|
||||||
@@ -88,6 +100,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
|||||||
sampling_temperature_ = options.temperature;
|
sampling_temperature_ = options.temperature;
|
||||||
sampling_top_p_ = options.top_p;
|
sampling_top_p_ = options.top_p;
|
||||||
sampling_top_k_ = options.top_k;
|
sampling_top_k_ = options.top_k;
|
||||||
|
|
||||||
if (options.seed == -1) {
|
if (options.seed == -1) {
|
||||||
std::random_device random_device;
|
std::random_device random_device;
|
||||||
rng_.seed(random_device());
|
rng_.seed(random_device());
|
||||||
@@ -100,26 +113,8 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
|||||||
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
|
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
|
||||||
sampling_top_k_};
|
sampling_top_k_};
|
||||||
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
|
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
|
||||||
sampler_.reset(new SamplerState());
|
sampler_ = std::make_unique<SamplerState>();
|
||||||
sampler_->chain = sampler_chain.release();
|
sampler_->chain = sampler_chain.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
LlamaGenerator::~LlamaGenerator() {
|
LlamaGenerator::~LlamaGenerator() = default;
|
||||||
sampler_.reset();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Free the inference context (contains KV cache and computation state)
|
|
||||||
*/
|
|
||||||
if (context_ != nullptr) {
|
|
||||||
llama_free(context_);
|
|
||||||
context_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Free the loaded model (contains weights and vocabulary)
|
|
||||||
*/
|
|
||||||
if (model_ != nullptr) {
|
|
||||||
llama_model_free(model_);
|
|
||||||
model_ = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,23 +9,19 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
void LlamaGenerator::Load(const std::string& model_path) {
|
void LlamaGenerator::Load(const std::string& model_path) {
|
||||||
if (context_ != nullptr) {
|
context_.reset();
|
||||||
llama_free(context_);
|
model_.reset();
|
||||||
context_ = nullptr;
|
|
||||||
}
|
|
||||||
if (model_ != nullptr) {
|
|
||||||
llama_model_free(model_);
|
|
||||||
model_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const llama_model_params model_params = llama_model_default_params();
|
const llama_model_params model_params = llama_model_default_params();
|
||||||
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
|
LlamaGenerator::ModelHandle loaded_model(
|
||||||
if (model_ == nullptr) {
|
llama_model_load_from_file(model_path.c_str(), model_params));
|
||||||
|
if (!loaded_model) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: failed to load model from path: " + model_path);
|
"LlamaGenerator: failed to load model from path: " + model_path);
|
||||||
}
|
}
|
||||||
@@ -34,12 +30,14 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
|||||||
context_params.n_ctx = n_ctx_;
|
context_params.n_ctx = n_ctx_;
|
||||||
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
|
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
|
||||||
|
|
||||||
context_ = llama_init_from_model(model_, context_params);
|
LlamaGenerator::ContextHandle loaded_context(
|
||||||
if (context_ == nullptr) {
|
llama_init_from_model(loaded_model.get(), context_params));
|
||||||
llama_model_free(model_);
|
if (!loaded_context) {
|
||||||
model_ = nullptr;
|
|
||||||
throw std::runtime_error("LlamaGenerator: failed to create context");
|
throw std::runtime_error("LlamaGenerator: failed to create context");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model_ = std::move(loaded_model);
|
||||||
|
context_ = std::move(loaded_context);
|
||||||
|
|
||||||
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
|
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user