diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index da58a4d..ef0189b 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -19,7 +19,6 @@ struct llama_model; struct llama_context; -struct llama_sampler; /** * @brief Data generator implementation backed by llama.cpp. @@ -78,13 +77,9 @@ class LlamaGenerator final : public DataGenerator { struct ContextDeleter { void operator()(llama_context* context) const noexcept; }; - struct SamplerDeleter { - void operator()(llama_sampler* sampler) const noexcept; - }; using ModelHandle = std::unique_ptr; using ContextHandle = std::unique_ptr; - using SamplerChainHandle = std::unique_ptr; /** * @brief Loads model and prepares inference context. @@ -102,20 +97,24 @@ class LlamaGenerator final : public DataGenerator { * @param system_prompt System role prompt. * @param prompt User prompt. * @param max_tokens Maximum tokens to generate. + * @param grammar Optional GBNF grammar constraining generated output. * @return Generated text. */ std::string Infer(const std::string& system_prompt, const std::string& prompt, - int max_tokens = kDefaultMaxTokens); + int max_tokens = kDefaultMaxTokens, + std::string_view grammar = {}); /** * @brief Runs inference on an already-formatted prompt. * * @param formatted_prompt Prompt preformatted for model chat template. * @param max_tokens Maximum tokens to generate. + * @param grammar Optional GBNF grammar constraining generated output. * @return Generated text. */ std::string InferFormatted(const std::string& formatted_prompt, - int max_tokens = kDefaultMaxTokens); + int max_tokens = kDefaultMaxTokens, + std::string_view grammar = {}); /** * @brief Loads the brewery system prompt from disk. @@ -127,8 +126,6 @@ class LlamaGenerator final : public DataGenerator { ModelHandle model_; ContextHandle context_; - /// @brief Persistent sampler chain reused across inference calls. - SamplerChainHandle sampler_; float sampling_temperature_ = 1.0F; float sampling_top_p_ = kDefaultSamplingTopP; uint32_t sampling_top_k_ = kDefaultSamplingTopK; diff --git a/pipeline/includes/data_generation/llama_generator_helpers.h b/pipeline/includes/data_generation/llama_generator_helpers.h index 11fe593..b1e26db 100644 --- a/pipeline/includes/data_generation/llama_generator_helpers.h +++ b/pipeline/includes/data_generation/llama_generator_helpers.h @@ -14,7 +14,7 @@ struct llama_model; struct llama_vocab; -typedef int32_t llama_token; +using llama_token = int32_t; /** * @brief Normalizes and truncates regional context. @@ -60,12 +60,4 @@ std::optional ValidateBreweryJson(const std::string& raw, std::string& name_out, std::string& description_out); -/** - * @brief Extracts the last balanced JSON object from text. - * - * @param text Input text. - * @return Extracted JSON object or an empty string if none exists. - */ -std::string ExtractLastJsonObject(const std::string& text); - #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ diff --git a/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc b/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc index 41a0f98..5cf60b6 100644 --- a/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc +++ b/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc @@ -13,7 +13,7 @@ #include "biergarten_data_generator.h" #include "json_handling/json_loader.h" -static constexpr size_t kBreweryAmount = 4; +static constexpr size_t kBreweryAmount = 50; std::vector BiergartenDataGenerator::QueryCitiesWithCountries() { spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); diff --git a/pipeline/src/data_generation/llama/generate_brewery.cc b/pipeline/src/data_generation/llama/generate_brewery.cc index 5ddc326..f511ed4 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cc +++ b/pipeline/src/data_generation/llama/generate_brewery.cc @@ -6,56 +6,24 @@ #include "data_generation/llama_generator.h" -#include #include #include #include #include +#include #include #include "data_generation/llama_generator_helpers.h" -static std::string ExtractFinalJsonPayload(std::string raw_response) { - auto trim = [](const std::string_view text) -> std::string_view { - const size_t first = text.find_first_not_of(" \t\n\r"); - if (first == std::string_view::npos) { - return {}; - } - - const size_t last = text.find_last_not_of(" \t\n\r"); - return text.substr(first, last - first + 1); - }; - - static constexpr std::array separator_tokens = { - "<|think|>", "", "<|turn|>", - "", "", "<|channel|>"}; - - size_t separator_pos = std::string::npos; - size_t separator_length = 0; - for (const std::string_view token : separator_tokens) { - const size_t candidate_pos = raw_response.rfind(token); - if (candidate_pos != std::string::npos && - (separator_pos == std::string::npos || candidate_pos > separator_pos)) { - separator_pos = candidate_pos; - separator_length = token.size(); - } - } - - if (separator_pos != std::string::npos) { - raw_response.erase(0, separator_pos + separator_length); - } - - const std::string_view trimmed = trim(raw_response); - const std::string json_candidate = - ExtractLastJsonObject(std::string(trimmed)); - - if (!json_candidate.empty()) { - return json_candidate; - } - - return std::string(trimmed); -} +static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery( +root ::= ws "{" ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws +ws ::= [ \t\n\r]* +string ::= "\"" char+ "\"" +char ::= [^"\\\x7F\x00-\x1F] | [\\] escape +escape ::= ["\\/bfnrt] | "u" hex hex hex hex +hex ::= [0-9a-fA-F] +)json_brewery"; BreweryResult LlamaGenerator::GenerateBrewery( const Location& location, const std::string& region_context) { @@ -108,7 +76,7 @@ BreweryResult LlamaGenerator::GenerateBrewery( for (int attempt = 0; attempt < max_attempts; ++attempt) { constexpr int max_tokens = 1052; // Generate brewery data from LLM - raw = this->Infer(system_prompt, prompt, max_tokens); + raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar); spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, raw); @@ -116,9 +84,8 @@ BreweryResult LlamaGenerator::GenerateBrewery( std::string name; std::string description; - const std::string json_only = ExtractFinalJsonPayload(raw); const std::optional validation_error = - ValidateBreweryJson(json_only, name, description); + ValidateBreweryJson(raw, name, description); if (!validation_error.has_value()) { // Success: return parsed brewery data return BreweryResult{.name = std::move(name), diff --git a/pipeline/src/data_generation/llama/helpers.cc b/pipeline/src/data_generation/llama/helpers.cc index 098af1d..88bbf5b 100644 --- a/pipeline/src/data_generation/llama/helpers.cc +++ b/pipeline/src/data_generation/llama/helpers.cc @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #include @@ -97,11 +96,16 @@ std::string ToChatPrompt(const llama_model* model, return combined_prompt; } - const std::array messages = { - {{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}}; + const std::array messages = {{ + {.role = "system", .content = system_prompt.c_str()}, + {.role = "user", .content = user_prompt.c_str()}, + }}; + + constexpr std::size_t min_template_buffer_size = 1024; std::vector buffer(std::max( - 1024, (system_prompt.size() + user_prompt.size()) * 4)); + min_template_buffer_size, + (system_prompt.size() + user_prompt.size()) * 4)); auto apply_template_with_resize = [&](const llama_chat_message* chat_messages, int32_t message_count) -> int32_t { @@ -113,11 +117,11 @@ std::string ToChatPrompt(const llama_model* model, return result; } - if (result >= static_cast(buffer.size())) { + const auto buffer_size = static_cast(buffer.size()); + if (result >= buffer_size) { buffer.resize(static_cast(result) + 1); result = llama_chat_apply_template(tmpl, chat_messages, message_count, - true, buffer.data(), - static_cast(buffer.size())); + true, buffer.data(), buffer_size); } return result; @@ -136,8 +140,9 @@ std::string ToChatPrompt(const llama_model* model, // FALLBACK: If the template fails (e.g., Model rejecting the "system" role), // combine the system and user prompts into a single "user" message. - const std::array fallback_msg = { - {{"user", combined_prompt.c_str()}}}; + const std::array fallback_msg = {{ + {.role = "user", .content = combined_prompt.c_str()}, + }}; template_result = apply_template_with_resize(fallback_msg.data(), 1); @@ -188,102 +193,17 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token, "LlamaGenerator: failed to decode sampled token piece"); } -// Shared parser used by the public extractor and JSON validation. -static bool ExtractLastJsonObject(const std::string& text, - std::string& json_out) { - // Remember where the most recent balanced object started. - size_t start = std::string::npos; - - // Track nested braces outside of quoted strings. - int depth = 0; - - // Track whether the scan is currently inside a quoted string. - bool in_string = false; - - // Track escape sequences so quotes inside strings are handled correctly. - bool escaped = false; - - // Record whether at least one complete object was found. - bool found = false; - - // Keep the latest complete object candidate. - std::string candidate; - - // Scan the input text one character at a time. - for (size_t i = 0; i < text.size(); ++i) { - // Inspect the current character. - const char chr = text[i]; - - // Inside a string literal, only escapes and quotes affect state. - if (in_string) { - if (escaped) { - // The current character was escaped, so clear the escape flag. - escaped = false; - } else if (chr == '\\') { - // Mark the next character as escaped. - escaped = true; - } else if (chr == '"') { - // Closing quote ends the string literal. - in_string = false; - } - continue; - } - - // Opening quotes enter string mode. - if (chr == '"') { - in_string = true; - continue; - } - - // Opening braces begin or nest a JSON object. - if (chr == '{') { - if (depth == 0) { - // Record the start of the outermost object. - start = i; - } - - // Increase nesting depth for the active object. - ++depth; - continue; - } - - // Closing braces may complete an object. - if (chr == '}') { - if (depth == 0) { - // Ignore stray closing braces. - continue; - } - - // Drop one level of nesting. - --depth; - if (depth == 0 && start != std::string::npos) { - // Capture the latest complete object seen so far. - candidate = text.substr(start, i - start + 1); - found = true; - } - } - } - - if (!found) { - return false; - } - - // Return the captured object text to the caller. - json_out = std::move(candidate); - return true; -} - std::optional ValidateBreweryJson(const std::string& raw, std::string& name_out, std::string& description_out) { - auto validate_object = [&](const boost::json::value& jv, + auto validate_object = [&](const boost::json::value& json_value, std::string& error_out) -> bool { - if (!jv.is_object()) { + if (!json_value.is_object()) { error_out = "JSON root must be an object"; return false; } - const auto& obj = jv.get_object(); + const auto& obj = json_value.get_object(); if (!obj.contains("name") || !obj.at("name").is_string()) { error_out = "JSON field 'name' is missing or not a string"; return false; @@ -313,14 +233,15 @@ std::optional ValidateBreweryJson(const std::string& raw, std::string name_lower = name_out; std::string description_lower = description_out; - std::transform( - name_lower.begin(), name_lower.end(), name_lower.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::ranges::transform(name_lower, name_lower.begin(), + [](unsigned char character) { + return static_cast(std::tolower(character)); + }); - std::transform(description_lower.begin(), description_lower.end(), - description_lower.begin(), [](unsigned char c) { - return static_cast(std::tolower(c)); - }); + std::ranges::transform(description_lower, description_lower.begin(), + [](unsigned char character) { + return static_cast(std::tolower(character)); + }); if (name_lower == "string" || description_lower == "string") { error_out = "JSON appears to be a schema placeholder, not content"; @@ -331,41 +252,16 @@ std::optional ValidateBreweryJson(const std::string& raw, return true; }; - boost::system::error_code ec; - boost::json::value jv = boost::json::parse(raw, ec); + boost::system::error_code error_code; + boost::json::value json_value = boost::json::parse(raw, error_code); std::string validation_error; - if (ec) { - std::string extracted; - if (!ExtractLastJsonObject(raw, extracted)) { - return "JSON parse error: " + ec.message(); - } - - ec.clear(); - jv = boost::json::parse(extracted, ec); - if (ec) { - return "JSON parse error: " + ec.message(); - } - - if (!validate_object(jv, validation_error)) { - return validation_error; - } - - return std::nullopt; + if (error_code) { + return "JSON parse error: " + error_code.message(); } - if (!validate_object(jv, validation_error)) { + if (!validate_object(json_value, validation_error)) { return validation_error; } return std::nullopt; } - -std::string ExtractLastJsonObject(const std::string& text) { - // Reuse the internal parser and return an empty string if none was found. - std::string extracted; - if (ExtractLastJsonObject(text, extracted)) { - return extracted; - } - - return {}; -} diff --git a/pipeline/src/data_generation/llama/infer.cc b/pipeline/src/data_generation/llama/infer.cc index ef24db2..bc47e13 100644 --- a/pipeline/src/data_generation/llama/infer.cc +++ b/pipeline/src/data_generation/llama/infer.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "data_generation/llama_generator.h" @@ -19,15 +20,68 @@ static constexpr size_t kPromptTokenSlack = 8; +namespace { + +using SamplerHandle = std::unique_ptr; + +struct SamplerConfig { + float temperature; + uint32_t top_k; + float top_p; + uint32_t seed; +}; + +SamplerHandle MakeSamplerChain(const llama_vocab* vocab, + const SamplerConfig& config, + std::string_view grammar) { + const llama_sampler_chain_params sampler_params = + llama_sampler_chain_default_params(); + + SamplerHandle chain(llama_sampler_chain_init(sampler_params), + llama_sampler_free); + if (!chain) { + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + } + + auto add_sampler = [&](llama_sampler* sampler, const char* error_message) { + if (sampler == nullptr) { + throw std::runtime_error(error_message); + } + + llama_sampler_chain_add(chain.get(), sampler); + }; + + if (!grammar.empty()) { + const std::string grammar_text(grammar); + add_sampler(llama_sampler_init_grammar(vocab, grammar_text.c_str(), "root"), + "LlamaGenerator: failed to initialize grammar sampler"); + } + + add_sampler(llama_sampler_init_temp(config.temperature), + "LlamaGenerator: failed to initialize temperature sampler"); + add_sampler(llama_sampler_init_top_k(static_cast(config.top_k)), + "LlamaGenerator: failed to initialize top-k sampler"); + add_sampler(llama_sampler_init_top_p(config.top_p, 1), + "LlamaGenerator: failed to initialize top-p sampler"); + add_sampler(llama_sampler_init_dist(config.seed), + "LlamaGenerator: failed to initialize distribution sampler"); + + return chain; +} + +} // namespace + std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, - const int max_tokens) { + const int max_tokens, + std::string_view grammar) { return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt), - max_tokens); + max_tokens, grammar); } std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, - const int max_tokens) { + const int max_tokens, + std::string_view grammar) { /** * Validate that model and context are loaded */ @@ -43,6 +97,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, throw std::runtime_error("LlamaGenerator: vocab unavailable"); } + const SamplerConfig sampler_config{ + .temperature = sampling_temperature_, + .top_k = sampling_top_k_, + .top_p = sampling_top_p_, + .seed = rng_(), + }; + auto sampler = MakeSamplerChain(vocab, sampler_config, grammar); + /** * Clear KV cache to ensure clean inference state (no residual context) */ @@ -140,17 +202,13 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, std::vector generated_tokens; generated_tokens.reserve(static_cast(effective_max_tokens)); - if (!sampler_) { - throw std::runtime_error("LlamaGenerator: sampler not initialized"); - } - for (int i = 0; i < effective_max_tokens; ++i) { /** * Sample next token using configured sampler chain and model logits * Index -1 means use the last output position from previous batch */ const llama_token next = - llama_sampler_sample(sampler_.get(), context_.get(), -1); + llama_sampler_sample(sampler.get(), context_.get(), -1); /** * Stop if model predicts end-of-generation token (EOS/EOT) */ diff --git a/pipeline/src/data_generation/llama/llama_generator.cc b/pipeline/src/data_generation/llama/llama_generator.cc index ccb4a3b..61ddf8a 100644 --- a/pipeline/src/data_generation/llama/llama_generator.cc +++ b/pipeline/src/data_generation/llama/llama_generator.cc @@ -30,13 +30,6 @@ void LlamaGenerator::ContextDeleter::operator()( } } -void LlamaGenerator::SamplerDeleter::operator()( - llama_sampler* sampler) const noexcept { - if (sampler != nullptr) { - llama_sampler_free(sampler); - } -} - LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, const std::string& model_path) : rng_(std::random_device{}()) { @@ -81,25 +74,6 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, n_ctx_ = options.n_ctx; this->Load(model_path); - const llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - - sampler_ = SamplerChainHandle(llama_sampler_chain_init(sampler_params)); - if (!sampler_) { - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - } - - llama_sampler_chain_add(sampler_.get(), - llama_sampler_init_temp(sampling_temperature_)); - - llama_sampler_chain_add( - sampler_.get(), - llama_sampler_init_top_k(static_cast(sampling_top_k_))); - - llama_sampler_chain_add(sampler_.get(), - llama_sampler_init_top_p(sampling_top_p_, 1)); - - llama_sampler_chain_add(sampler_.get(), llama_sampler_init_dist(rng_())); } LlamaGenerator::~LlamaGenerator() = default;