diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index e7e9901..da58a4d 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -1,6 +1,8 @@ #ifndef BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ #define BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ +#include + /** * @file data_generation/llama_generator.h * @brief llama.cpp-backed implementation of DataGenerator. @@ -34,12 +36,16 @@ class LlamaGenerator final : public DataGenerator { LlamaGenerator(const ApplicationOptions& options, const std::string& model_path); - /// @brief Releases model/context resources. ~LlamaGenerator() override; + // disable copy constructor LlamaGenerator(const LlamaGenerator&) = delete; + + // disable copy assignment operator LlamaGenerator& operator=(const LlamaGenerator&) = delete; + // disable move constructor LlamaGenerator(LlamaGenerator&&) = delete; + // disable move assignment operator LlamaGenerator& operator=(LlamaGenerator&&) = delete; /** @@ -61,7 +67,7 @@ class LlamaGenerator final : public DataGenerator { UserResult GenerateUser(const std::string& locale) override; private: - static constexpr int kDefaultMaxTokens = 10000; + static constexpr int32_t kDefaultMaxTokens = 10000; static constexpr float kDefaultSamplingTopP = 0.95F; static constexpr uint32_t kDefaultSamplingTopK = 64; static constexpr uint32_t kDefaultContextSize = 8192; @@ -69,25 +75,16 @@ class LlamaGenerator final : public DataGenerator { struct ModelDeleter { void operator()(llama_model* model) const noexcept; }; - struct ContextDeleter { void operator()(llama_context* context) const noexcept; }; + struct SamplerDeleter { + void operator()(llama_sampler* sampler) const noexcept; + }; using ModelHandle = std::unique_ptr; using ContextHandle = std::unique_ptr; - - struct SamplerState { - SamplerState() = default; - ~SamplerState(); - - 
SamplerState(const SamplerState&) = delete; - SamplerState& operator=(const SamplerState&) = delete; - SamplerState(SamplerState&&) = delete; - SamplerState& operator=(SamplerState&&) = delete; - - llama_sampler* chain = nullptr; - }; + using SamplerChainHandle = std::unique_ptr; /** * @brief Loads model and prepares inference context. @@ -126,12 +123,12 @@ class LlamaGenerator final : public DataGenerator { * @param prompt_file_path Prompt file path to try first. * @return Loaded prompt text. */ - std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path); + std::string LoadBrewerySystemPrompt(const std::filesystem::path& prompt_file_path); ModelHandle model_; ContextHandle context_; /// @brief Persistent sampler chain reused across inference calls. - std::unique_ptr sampler_; + SamplerChainHandle sampler_; float sampling_temperature_ = 1.0F; float sampling_top_p_ = kDefaultSamplingTopP; uint32_t sampling_top_k_ = kDefaultSamplingTopK; diff --git a/pipeline/includes/data_generation/llama_generator_helpers.h b/pipeline/includes/data_generation/llama_generator_helpers.h index bacdc64..11fe593 100644 --- a/pipeline/includes/data_generation/llama_generator_helpers.h +++ b/pipeline/includes/data_generation/llama_generator_helpers.h @@ -7,14 +7,14 @@ */ #include +#include #include #include #include -#include struct llama_model; struct llama_vocab; -typedef int llama_token; +typedef int32_t llama_token; /** * @brief Normalizes and truncates regional context. @@ -23,18 +23,8 @@ typedef int llama_token; * @param max_chars Maximum output length. * @return Processed region context. */ -std::string PrepareRegionContextPublic(std::string_view region_context, - std::size_t max_chars = 2000); - -/** - * @brief Parses a response expected to contain two logical lines. - * - * @param raw Raw model output. - * @param error_message Error message thrown on parse failure. - * @return Pair containing first and second parsed fields. 
- */ -std::pair ParseTwoLineResponsePublic( - const std::string& raw, const std::string& error_message); +std::string PrepareRegionContext(std::string_view region_context, + size_t max_chars = 2000); /** * @brief Applies model chat template to system and user prompts. @@ -44,9 +34,9 @@ std::pair ParseTwoLineResponsePublic( * @param user_prompt User prompt text. * @return Model-formatted prompt. */ -std::string ToChatPromptPublic(const llama_model* model, - const std::string& system_prompt, - const std::string& user_prompt); +std::string ToChatPrompt(const llama_model* model, + const std::string& system_prompt, + const std::string& user_prompt); /** * @brief Decodes a sampled token and appends it to output text. @@ -55,8 +45,8 @@ std::string ToChatPromptPublic(const llama_model* model, * @param token Sampled token id. * @param output Output text buffer. */ -void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, - std::string& output); +void AppendTokenPiece(const llama_vocab* vocab, llama_token token, + std::string& output); /** * @brief Validates and parses brewery JSON output. @@ -66,9 +56,9 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, * @param description_out Parsed brewery description. * @return Validation error message if invalid, or std::nullopt on success. */ -std::optional ValidateBreweryJsonPublic( - const std::string& raw, std::string& name_out, - std::string& description_out); +std::optional ValidateBreweryJson(const std::string& raw, + std::string& name_out, + std::string& description_out); /** * @brief Extracts the last balanced JSON object from text. @@ -76,6 +66,6 @@ std::optional ValidateBreweryJsonPublic( * @param text Input text. * @return Extracted JSON object or an empty string if none exists. 
*/ -std::string ExtractLastJsonObjectPublic(const std::string& text); +std::string ExtractLastJsonObject(const std::string& text); -#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ \ No newline at end of file +#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ diff --git a/pipeline/includes/data_generation/mock_generator.h b/pipeline/includes/data_generation/mock_generator.h index 0ca154b..d91c1e3 100644 --- a/pipeline/includes/data_generation/mock_generator.h +++ b/pipeline/includes/data_generation/mock_generator.h @@ -42,7 +42,7 @@ class MockGenerator final : public DataGenerator { * @param location City and country names. * @return Deterministic hash value. */ - static std::size_t DeterministicHash(const Location& location); + static size_t DeterministicHash(const Location& location); static constexpr std::array kBreweryAdjectives = { "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", diff --git a/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc b/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc index ec1b8b2..41a0f98 100644 --- a/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc +++ b/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc @@ -3,18 +3,17 @@ * @brief BiergartenDataGenerator::QueryCitiesWithCountries() implementation. 
*/ -#include "biergarten_data_generator.h" +#include #include #include #include #include -#include - +#include "biergarten_data_generator.h" #include "json_handling/json_loader.h" -static constexpr std::size_t kBreweryAmount = 4; +static constexpr size_t kBreweryAmount = 4; std::vector BiergartenDataGenerator::QueryCitiesWithCountries() { spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); @@ -24,11 +23,12 @@ std::vector BiergartenDataGenerator::QueryCitiesWithCountries() { auto all_locations = JsonLoader::LoadLocations(locations_path); spdlog::info(" Locations available: {}", all_locations.size()); - const std::size_t sample_count = - std::min(kBreweryAmount, all_locations.size()); + const size_t sample_count = std::min(kBreweryAmount, all_locations.size()); + const auto sample_count_signed = static_cast>( sample_count); + std::vector sampled_locations; sampled_locations.reserve(sample_count); diff --git a/pipeline/src/data_generation/llama/generate_brewery.cc b/pipeline/src/data_generation/llama/generate_brewery.cc index 0cd8795..5ddc326 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cc +++ b/pipeline/src/data_generation/llama/generate_brewery.cc @@ -18,12 +18,12 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) { auto trim = [](const std::string_view text) -> std::string_view { - const std::size_t first = text.find_first_not_of(" \t\n\r"); + const size_t first = text.find_first_not_of(" \t\n\r"); if (first == std::string_view::npos) { return {}; } - const std::size_t last = text.find_last_not_of(" \t\n\r"); + const size_t last = text.find_last_not_of(" \t\n\r"); return text.substr(first, last - first + 1); }; @@ -31,10 +31,10 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) { "<|think|>", "", "<|turn|>", "", "", "<|channel|>"}; - std::size_t separator_pos = std::string::npos; - std::size_t separator_length = 0; + size_t separator_pos = std::string::npos; + size_t separator_length = 0; for (const 
std::string_view token : separator_tokens) { - const std::size_t candidate_pos = raw_response.rfind(token); + const size_t candidate_pos = raw_response.rfind(token); if (candidate_pos != std::string::npos && (separator_pos == std::string::npos || candidate_pos > separator_pos)) { separator_pos = candidate_pos; @@ -48,10 +48,10 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) { const std::string_view trimmed = trim(raw_response); const std::string json_candidate = - ExtractLastJsonObjectPublic(std::string(trimmed)); + ExtractLastJsonObject(std::string(trimmed)); if (!json_candidate.empty()) { - return ExtractLastJsonObjectPublic(std::string(trimmed)); + return json_candidate; } return std::string(trimmed); @@ -63,7 +63,7 @@ BreweryResult LlamaGenerator::GenerateBrewery( * Preprocess and truncate region context to manageable size */ const std::string safe_region_context = - PrepareRegionContextPublic(region_context); + PrepareRegionContext(region_context); const std::string country_suffix = location.country.empty() ? 
std::string{} @@ -118,7 +118,7 @@ BreweryResult LlamaGenerator::GenerateBrewery( std::string description; const std::string json_only = ExtractFinalJsonPayload(raw); const std::optional validation_error = - ValidateBreweryJsonPublic(json_only, name, description); + ValidateBreweryJson(json_only, name, description); if (!validation_error.has_value()) { // Success: return parsed brewery data return BreweryResult{.name = std::move(name), diff --git a/pipeline/src/data_generation/llama/helpers.cc b/pipeline/src/data_generation/llama/helpers.cc index 0cadf58..098af1d 100644 --- a/pipeline/src/data_generation/llama/helpers.cc +++ b/pipeline/src/data_generation/llama/helpers.cc @@ -17,7 +17,7 @@ #include #include -#include "data_generation/llama_generator.h" +#include "data_generation/llama_generator_helpers.h" #include "llama.h" /** @@ -25,12 +25,12 @@ */ static std::string Trim(std::string_view value) { constexpr std::string_view whitespace = " \t\n\r\f\v"; - const std::size_t first_index = value.find_first_not_of(whitespace); + const size_t first_index = value.find_first_not_of(whitespace); if (first_index == std::string_view::npos) { return {}; } - const std::size_t last_index = value.find_last_not_of(whitespace); + const size_t last_index = value.find_last_not_of(whitespace); return std::string(value.substr(first_index, last_index - first_index + 1)); } @@ -43,7 +43,8 @@ static std::string CondenseWhitespace(std::string_view text) { out.reserve(text.size()); bool pending_space = false; - for (const unsigned char chr : text) { + for (const char chr : text) { - if (std::isspace(chr) != 0) { + // std::isspace has undefined behavior for negative char values, so classify the byte as unsigned char. + if (std::isspace(static_cast<unsigned char>(chr)) != 0) { if (!out.empty()) { pending_space = true; @@ -55,7 +55,7 @@ static std::string CondenseWhitespace(std::string_view text) { out.push_back(' '); pending_space = false; } - out.push_back(static_cast(chr)); + out.push_back(chr); } return out; @@ -65,8 +65,8 @@ static std::string CondenseWhitespace(std::string_view text) { * Truncate region context to fit within max length while preserving 
word * boundaries */ -static std::string PrepareRegionContext(std::string_view region_context, - const size_t max_chars) { +std::string PrepareRegionContext(std::string_view region_context, + const size_t max_chars) { std::string normalized = CondenseWhitespace(region_context); if (normalized.size() <= max_chars) { return normalized; @@ -82,11 +82,10 @@ static std::string PrepareRegionContext(std::string_view region_context, return normalized; } -static std::string ToChatPrompt(const llama_model* model, - const std::string& system_prompt, - const std::string& user_prompt) { - std::string combined_prompt; - combined_prompt.append(system_prompt); +std::string ToChatPrompt(const llama_model* model, + const std::string& system_prompt, + const std::string& user_prompt) { + std::string combined_prompt = system_prompt; combined_prompt.append("\n\n"); combined_prompt.append(user_prompt); @@ -127,7 +126,7 @@ static std::string ToChatPrompt(const llama_model* model, int32_t template_result = apply_template_with_resize(messages.data(), 2); if (template_result >= 0) { - return {buffer.data(), static_cast(template_result)}; + return {buffer.data(), static_cast(template_result)}; } spdlog::warn( @@ -151,74 +150,116 @@ static std::string ToChatPrompt(const llama_model* model, return combined_prompt; } - return {buffer.data(), static_cast(template_result)}; + return {buffer.data(), static_cast(template_result)}; } -static void AppendTokenPiece(const llama_vocab* vocab, llama_token token, - std::string& output) { - std::array buffer{}; +void AppendTokenPiece(const llama_vocab* vocab, llama_token token, + std::string& output) { + constexpr size_t initial_buffer_size = 256; + + std::array buffer{}; + + // serialize the sampled token into UTF-8 bytes + + auto buffer_too_small = [](int32_t result) -> bool { return result < 0; }; + int32_t bytes = llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true); - if (bytes < 0) { - std::vector 
dynamic_buffer(static_cast(-bytes)); - bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), - static_cast(dynamic_buffer.size()), 0, - true); - if (bytes < 0) { - throw std::runtime_error( - "LlamaGenerator: failed to decode sampled token piece"); - } - - output.append(dynamic_buffer.data(), static_cast(bytes)); + if (!buffer_too_small(bytes)) { + // Append the decoded bytes from the stack buffer. + output.append(buffer.data(), static_cast(bytes)); return; } - output.append(buffer.data(), static_cast(bytes)); + const int32_t required_size = -bytes; + std::vector dynamic_buffer(static_cast(required_size)); + + // Retry token decoding against the larger heap buffer. + bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), + static_cast(dynamic_buffer.size()), 0, + true); + + if (!buffer_too_small(bytes)) { + output.append(dynamic_buffer.data(), static_cast(bytes)); + // Retry succeeded: return here so the throw below only runs when decoding really failed. + return; + } + + throw std::runtime_error( + "LlamaGenerator: failed to decode sampled token piece"); } +// Shared parser used by the public extractor and JSON validation. static bool ExtractLastJsonObject(const std::string& text, std::string& json_out) { - std::size_t start = std::string::npos; + // Remember where the most recent balanced object started. + size_t start = std::string::npos; + + // Track nested braces outside of quoted strings. int depth = 0; + + // Track whether the scan is currently inside a quoted string. bool in_string = false; + + // Track escape sequences so quotes inside strings are handled correctly. bool escaped = false; + + // Record whether at least one complete object was found. bool found = false; + + // Keep the latest complete object candidate. std::string candidate; - for (std::size_t i = 0; i < text.size(); ++i) { - const char ch = text[i]; + // Scan the input text one character at a time. + for (size_t i = 0; i < text.size(); ++i) { + // Inspect the current character. 
+ const char chr = text[i]; + // Inside a string literal, only escapes and quotes affect state. if (in_string) { if (escaped) { + // The current character was escaped, so clear the escape flag. escaped = false; - } else if (ch == '\\') { + } else if (chr == '\\') { + // Mark the next character as escaped. escaped = true; - } else if (ch == '"') { + } else if (chr == '"') { + // Closing quote ends the string literal. in_string = false; } continue; } - if (ch == '"') { + // Opening quotes enter string mode. + if (chr == '"') { in_string = true; continue; } - if (ch == '{') { + // Opening braces begin or nest a JSON object. + if (chr == '{') { if (depth == 0) { + // Record the start of the outermost object. start = i; } + + // Increase nesting depth for the active object. ++depth; continue; } - if (ch == '}') { + // Closing braces may complete an object. + if (chr == '}') { if (depth == 0) { + // Ignore stray closing braces. continue; } + + // Drop one level of nesting. --depth; if (depth == 0 && start != std::string::npos) { + // Capture the latest complete object seen so far. candidate = text.substr(start, i - start + 1); found = true; } @@ -229,22 +268,14 @@ static bool ExtractLastJsonObject(const std::string& text, return false; } + // Return the captured object text to the caller. 
json_out = std::move(candidate); return true; } -std::string ExtractLastJsonObjectPublic(const std::string& text) { - std::string extracted; - if (ExtractLastJsonObject(text, extracted)) { - return extracted; - } - - return {}; -} - -static std::optional ValidateBreweryJson( - const std::string& raw, std::string& name_out, - std::string& description_out) { +std::optional ValidateBreweryJson(const std::string& raw, + std::string& name_out, + std::string& description_out) { auto validate_object = [&](const boost::json::value& jv, std::string& error_out) -> bool { if (!jv.is_object()) { @@ -281,9 +312,11 @@ static std::optional ValidateBreweryJson( std::string name_lower = name_out; std::string description_lower = description_out; + std::transform( name_lower.begin(), name_lower.end(), name_lower.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::transform(description_lower.begin(), description_lower.end(), description_lower.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); @@ -327,25 +360,12 @@ static std::optional ValidateBreweryJson( return std::nullopt; } -// Forward declarations for helper functions exposed to other translation units -std::string PrepareRegionContextPublic(std::string_view region_context, - std::size_t max_chars) { - return PrepareRegionContext(region_context, max_chars); -} +std::string ExtractLastJsonObject(const std::string& text) { + // Reuse the internal parser and return an empty string if none was found. 
+ std::string extracted; + if (ExtractLastJsonObject(text, extracted)) { + return extracted; + } -std::string ToChatPromptPublic(const llama_model* model, - const std::string& system_prompt, - const std::string& user_prompt) { - return ToChatPrompt(model, system_prompt, user_prompt); -} - -void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, - std::string& output) { - AppendTokenPiece(vocab, token, output); -} - -std::optional ValidateBreweryJsonPublic( - const std::string& raw, std::string& name_out, - std::string& description_out) { - return ValidateBreweryJson(raw, name_out, description_out); + return {}; } diff --git a/pipeline/src/data_generation/llama/infer.cc b/pipeline/src/data_generation/llama/infer.cc index 77e4787..ef24db2 100644 --- a/pipeline/src/data_generation/llama/infer.cc +++ b/pipeline/src/data_generation/llama/infer.cc @@ -17,12 +17,12 @@ #include "data_generation/llama_generator_helpers.h" #include "llama.h" -static constexpr std::size_t kPromptTokenSlack = 8; +static constexpr size_t kPromptTokenSlack = 8; std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, const int max_tokens) { - return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt), + return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt), max_tokens); } @@ -54,16 +54,26 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, */ std::vector prompt_tokens(formatted_prompt.size() + kPromptTokenSlack); + + + + int32_t token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); + vocab, + formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), + prompt_tokens.data(), + static_cast(prompt_tokens.size()), + true, + true); /** * If buffer too small, negative return indicates required size */ if (token_count < 0) { - 
prompt_tokens.resize(static_cast(-token_count)); + prompt_tokens.resize(static_cast(-token_count)); + + token_count = llama_tokenize( vocab, formatted_prompt.c_str(), static_cast(formatted_prompt.size()), prompt_tokens.data(), @@ -91,6 +101,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, */ const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); + /** * Prompt can use remaining context after reserving space for generation */ @@ -100,13 +111,13 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, /** * Truncate prompt if necessary to fit within constraints */ - prompt_tokens.resize(static_cast(token_count)); + prompt_tokens.resize(static_cast(token_count)); if (token_count > prompt_budget) { spdlog::warn( "LlamaGenerator: prompt too long ({} tokens), truncating to {} " "tokens to fit n_batch/n_ctx limits", token_count, prompt_budget); - prompt_tokens.resize(static_cast(prompt_budget)); + prompt_tokens.resize(static_cast(prompt_budget)); token_count = prompt_budget; } @@ -127,9 +138,9 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, * end-of-sequence */ std::vector generated_tokens; - generated_tokens.reserve(static_cast(effective_max_tokens)); + generated_tokens.reserve(static_cast(effective_max_tokens)); - if (sampler_ == nullptr || sampler_->chain == nullptr) { + if (!sampler_) { throw std::runtime_error("LlamaGenerator: sampler not initialized"); } @@ -139,7 +150,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, * Index -1 means use the last output position from previous batch */ const llama_token next = - llama_sampler_sample(sampler_->chain, context_.get(), -1); + llama_sampler_sample(sampler_.get(), context_.get(), -1); /** * Stop if model predicts end-of-generation token (EOS/EOT) */ @@ -165,7 +176,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, */ std::string 
output; for (const llama_token token : generated_tokens) { - AppendTokenPiecePublic(vocab, token, output); + AppendTokenPiece(vocab, token, output); } return output; diff --git a/pipeline/src/data_generation/llama/llama_generator.cc b/pipeline/src/data_generation/llama/llama_generator.cc index 7571b4d..ccb4a3b 100644 --- a/pipeline/src/data_generation/llama/llama_generator.cc +++ b/pipeline/src/data_generation/llama/llama_generator.cc @@ -9,60 +9,31 @@ #include #include #include +#include #include "data_model/application_options.h" #include "llama.h" static constexpr uint32_t kMaxContextSize = 32768U; -struct SamplerConfig { - float temperature; - float top_p; - uint32_t top_k; -}; - -using SamplerPtr = - std::unique_ptr; - -void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept { +void LlamaGenerator::ModelDeleter::operator()( + llama_model* model) const noexcept { if (model != nullptr) { llama_model_free(model); } } -void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept { +void LlamaGenerator::ContextDeleter::operator()( + llama_context* context) const noexcept { if (context != nullptr) { llama_free(context); } } -static SamplerPtr CreateSamplerChain(const SamplerConfig& config, - std::mt19937& rng) { - const llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - - SamplerPtr sampler(llama_sampler_chain_init(sampler_params), - &llama_sampler_free); - if (!sampler) { - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - } - - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_temp(config.temperature)); - llama_sampler_chain_add( - sampler.get(), - llama_sampler_init_top_k(static_cast(config.top_k))); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_top_p(config.top_p, 1)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng())); - - return sampler; -} - -LlamaGenerator::SamplerState::~SamplerState() { - if (chain 
!= nullptr) { - llama_sampler_free(chain); - chain = nullptr; +void LlamaGenerator::SamplerDeleter::operator()( + llama_sampler* sampler) const noexcept { + if (sampler != nullptr) { + llama_sampler_free(sampler); } } @@ -110,11 +81,25 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, n_ctx_ = options.n_ctx; this->Load(model_path); - const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_, - sampling_top_k_}; - auto sampler_chain = CreateSamplerChain(sampler_config, rng_); - sampler_ = std::make_unique(); - sampler_->chain = sampler_chain.release(); + const llama_sampler_chain_params sampler_params = + llama_sampler_chain_default_params(); + + sampler_ = SamplerChainHandle(llama_sampler_chain_init(sampler_params)); + if (!sampler_) { + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + } + + llama_sampler_chain_add(sampler_.get(), + llama_sampler_init_temp(sampling_temperature_)); + + llama_sampler_chain_add( + sampler_.get(), + llama_sampler_init_top_k(static_cast(sampling_top_k_))); + + llama_sampler_chain_add(sampler_.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + + llama_sampler_chain_add(sampler_.get(), llama_sampler_init_dist(rng_())); } LlamaGenerator::~LlamaGenerator() = default; diff --git a/pipeline/src/data_generation/llama/load_brewery_prompt.cc b/pipeline/src/data_generation/llama/load_brewery_prompt.cc index 87e0eea..242eda8 100644 --- a/pipeline/src/data_generation/llama/load_brewery_prompt.cc +++ b/pipeline/src/data_generation/llama/load_brewery_prompt.cc @@ -12,8 +12,6 @@ #include "data_generation/llama_generator.h" -namespace fs = std::filesystem; - /** * @brief Loads brewery system prompt from disk or cache. * @@ -21,22 +19,21 @@ namespace fs = std::filesystem; * @return Prompt text loaded from disk. 
*/ std::string LlamaGenerator::LoadBrewerySystemPrompt( - const std::string& prompt_file_path) { + const std::filesystem::path& prompt_file_path) { // Return cached version if already loaded if (!brewery_system_prompt_.empty()) { return brewery_system_prompt_; } - // Try the provided path only - const fs::path prompt_path(prompt_file_path); - std::ifstream prompt_file(prompt_path); + + std::ifstream prompt_file(prompt_file_path); if (!prompt_file.is_open()) { spdlog::error( "LlamaGenerator: Failed to open brewery system prompt file '{}'", - prompt_path.string()); + prompt_file_path.string()); throw std::runtime_error( "LlamaGenerator: missing brewery system prompt file: " + - prompt_path.string()); + prompt_file_path.string()); } const std::string prompt((std::istreambuf_iterator(prompt_file)), @@ -45,15 +42,15 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt( if (prompt.empty()) { spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty", - prompt_path.string()); + prompt_file_path.string()); throw std::runtime_error( "LlamaGenerator: empty brewery system prompt file: " + - prompt_path.string()); + prompt_file_path.string()); } spdlog::info( "LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)", - prompt_path.string(), prompt.length()); + prompt_file_path.string(), prompt.length()); brewery_system_prompt_ = prompt; return brewery_system_prompt_; } diff --git a/pipeline/src/data_generation/mock/generate_brewery.cc b/pipeline/src/data_generation/mock/generate_brewery.cc index 442694c..c2add93 100644 --- a/pipeline/src/data_generation/mock/generate_brewery.cc +++ b/pipeline/src/data_generation/mock/generate_brewery.cc @@ -12,7 +12,7 @@ BreweryResult MockGenerator::GenerateBrewery( const Location& location, const std::string& /*region_context*/) { - const std::size_t hash = DeterministicHash(location); + const size_t hash = DeterministicHash(location); const std::string_view adjective = kBreweryAdjectives.at(hash % 
kBreweryAdjectives.size()); diff --git a/pipeline/src/data_generation/mock/generate_user.cc b/pipeline/src/data_generation/mock/generate_user.cc index 0d259c6..51c26d2 100644 --- a/pipeline/src/data_generation/mock/generate_user.cc +++ b/pipeline/src/data_generation/mock/generate_user.cc @@ -11,7 +11,7 @@ #include "data_generation/mock_generator.h" UserResult MockGenerator::GenerateUser(const std::string& locale) { - const std::size_t hash = std::hash{}(locale); + const size_t hash = std::hash{}(locale); UserResult result; const std::string_view username = kUsernames[hash % kUsernames.size()]; diff --git a/pipeline/src/main.cc b/pipeline/src/main.cc index 954a4fc..65699a7 100644 --- a/pipeline/src/main.cc +++ b/pipeline/src/main.cc @@ -4,16 +4,16 @@ * initializes shared infrastructure, and executes the pipeline entry flow. */ +#include + +#include +#include #include #include #include #include #include -#include -#include -#include - #include "biergarten_data_generator.h" #include "data_generation/llama_generator.h" #include "data_generation/mock_generator.h"