diff --git a/pipeline/CMakeLists.txt b/pipeline/CMakeLists.txt index 2527116..3e2ec6d 100644 --- a/pipeline/CMakeLists.txt +++ b/pipeline/CMakeLists.txt @@ -83,7 +83,13 @@ set(PIPELINE_SOURCES src/data_generation/data_downloader.cpp src/database/database.cpp src/json_handling/json_loader.cpp - src/data_generation/llama_generator.cpp + src/data_generation/llama/destructor.cpp + src/data_generation/llama/set_sampling_options.cpp + src/data_generation/llama/load.cpp + src/data_generation/llama/infer.cpp + src/data_generation/llama/generate_brewery.cpp + src/data_generation/llama/generate_user.cpp + src/data_generation/llama/helpers.cpp src/data_generation/mock_generator.cpp src/json_handling/stream_parser.cpp src/wikipedia/wikipedia_service.cpp diff --git a/pipeline/includes/biergarten_data_generator.h b/pipeline/includes/biergarten_data_generator.h index b9e00b6..9371031 100644 --- a/pipeline/includes/biergarten_data_generator.h +++ b/pipeline/includes/biergarten_data_generator.h @@ -106,8 +106,8 @@ private: * @brief Helper struct to store generated brewery data. */ struct GeneratedBrewery { - int cityId; - std::string cityName; + int city_id; + std::string city_name; BreweryResult brewery; }; diff --git a/pipeline/includes/data_generation/llama_generator_helpers.h b/pipeline/includes/data_generation/llama_generator_helpers.h new file mode 100644 index 0000000..11331de --- /dev/null +++ b/pipeline/includes/data_generation/llama_generator_helpers.h @@ -0,0 +1,33 @@ +#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ +#define BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ + +#include +#include + +struct llama_model; +struct llama_vocab; +typedef int llama_token; + +// Helper functions for LlamaGenerator methods +std::string PrepareRegionContextPublic(std::string_view region_context, + std::size_t max_chars = 700); + +std::pair +ParseTwoLineResponsePublic(const std::string& raw, + const std::string& error_message); + +std::string ToChatPromptPublic(const llama_model *model, + const std::string& user_prompt); + +std::string ToChatPromptPublic(const llama_model *model, + const std::string& system_prompt, + const std::string& user_prompt); + +void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, + std::string& output); + +std::string ValidateBreweryJsonPublic(const std::string& raw, + std::string& name_out, + std::string& description_out); + +#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ diff --git a/pipeline/src/biergarten_data_generator.cpp b/pipeline/src/biergarten_data_generator.cpp index 12a37eb..c2d2389 100644 --- a/pipeline/src/biergarten_data_generator.cpp +++ b/pipeline/src/biergarten_data_generator.cpp @@ -111,8 +111,8 @@ void BiergartenDataGenerator::GenerateSampleBreweries() { spdlog::info("\n=== GENERATED DATA DUMP ==="); for (size_t i = 0; i < generatedBreweries_.size(); i++) { const auto &entry = generatedBreweries_[i]; - spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId, - entry.cityName); + spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.city_id, + entry.city_name); spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); } diff --git a/pipeline/src/data_generation/llama/destructor.cpp b/pipeline/src/data_generation/llama/destructor.cpp new file mode 100644 index 0000000..1cdde40 --- /dev/null +++ b/pipeline/src/data_generation/llama/destructor.cpp @@ -0,0 +1,17 @@ +#include "llama.h" + +#include "data_generation/llama_generator.h" + +LlamaGenerator::~LlamaGenerator() { + if (context_ != nullptr) { + llama_free(context_); + context_ = nullptr; + } + + if (model_ != nullptr) { + llama_model_free(model_); + model_ = nullptr; + } + + llama_backend_free(); +} diff --git a/pipeline/src/data_generation/llama/generate_brewery.cpp b/pipeline/src/data_generation/llama/generate_brewery.cpp new file mode 100644 index 0000000..ff0b663 --- /dev/null +++ b/pipeline/src/data_generation/llama/generate_brewery.cpp @@ -0,0 +1,72 @@ +#include +#include + +#include + +#include "data_generation/llama_generator.h" +#include "data_generation/llama_generator_helpers.h" + +BreweryResult +LlamaGenerator::GenerateBrewery(const std::string& city_name, + const std::string& country_name, + const std::string& region_context) { + const std::string safe_region_context = + PrepareRegionContextPublic(region_context); + + const std::string system_prompt = + "You are a copywriter for a craft beer travel guide. " + "Your writing is vivid, specific to place, and avoids generic beer " + "cliches. " + "You must output ONLY valid JSON. " + "The JSON schema must be exactly: {\"name\": \"string\", " + "\"description\": \"string\"}. " + "Do not include markdown formatting or backticks."; + + std::string prompt = + "Write a brewery name and place-specific description for a craft " + "brewery in " + + city_name + + (country_name.empty() ? std::string("") + : std::string(", ") + country_name) + + (safe_region_context.empty() + ? std::string(".") + : std::string(". Regional context: ") + safe_region_context); + + const int max_attempts = 3; + std::string raw; + std::string last_error; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + raw = Infer(system_prompt, prompt, 384); + spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, + raw); + + std::string name; + std::string description; + const std::string validation_error = + ValidateBreweryJsonPublic(raw, name, description); + if (validation_error.empty()) { + return {std::move(name), std::move(description)}; + } + + last_error = validation_error; + spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", + attempt + 1, validation_error); + + prompt = "Your previous response was invalid. Error: " + validation_error + + "\nReturn ONLY valid JSON with this exact schema: " + "{\"name\": \"string\", \"description\": \"string\"}." + "\nDo not include markdown, comments, or extra keys." + "\n\nLocation: " + + city_name + + (country_name.empty() ? std::string("") + : std::string(", ") + country_name) + + (safe_region_context.empty() + ? std::string("") + : std::string("\nRegional context: ") + safe_region_context); + } + + spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " + "{}", + max_attempts, last_error.empty() ? raw : last_error); + throw std::runtime_error("LlamaGenerator: malformed brewery response"); +} diff --git a/pipeline/src/data_generation/llama/generate_user.cpp b/pipeline/src/data_generation/llama/generate_user.cpp new file mode 100644 index 0000000..4cf8671 --- /dev/null +++ b/pipeline/src/data_generation/llama/generate_user.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +#include + +#include "data_generation/llama_generator.h" +#include "data_generation/llama_generator_helpers.h" + +UserResult LlamaGenerator::GenerateUser(const std::string& locale) { + const std::string system_prompt = + "You generate plausible social media profiles for craft beer " + "enthusiasts. " + "Respond with exactly two lines: " + "the first line is a username (lowercase, no spaces, 8-20 characters), " + "the second line is a one-sentence bio (20-40 words). " + "The profile should feel consistent with the locale. " + "No preamble, no labels."; + + std::string prompt = + "Generate a craft beer enthusiast profile. Locale: " + locale; + + const int max_attempts = 3; + std::string raw; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + raw = Infer(system_prompt, prompt, 128); + spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}", + attempt + 1, raw); + + try { + auto [username, bio] = ParseTwoLineResponsePublic( + raw, "LlamaGenerator: malformed user response"); + + username.erase( + std::remove_if(username.begin(), username.end(), + [](unsigned char ch) { return std::isspace(ch); }), + username.end()); + + if (username.empty() || bio.empty()) { + throw std::runtime_error("LlamaGenerator: malformed user response"); + } + + if (bio.size() > 200) + bio = bio.substr(0, 200); + + return {username, bio}; + } catch (const std::exception &e) { + spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}", + attempt + 1, e.what()); + } + } + + spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}", + max_attempts, raw); + throw std::runtime_error("LlamaGenerator: malformed user response"); +} diff --git a/pipeline/src/data_generation/llama/helpers.cpp b/pipeline/src/data_generation/llama/helpers.cpp new file mode 100644 index 0000000..c1343fc --- /dev/null +++ b/pipeline/src/data_generation/llama/helpers.cpp @@ -0,0 +1,401 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "llama.h" +#include + +#include "data_generation/llama_generator.h" + +namespace { + +std::string Trim(std::string value) { + auto not_space = [](unsigned char ch) { return !std::isspace(ch); }; + + value.erase(value.begin(), + std::find_if(value.begin(), value.end(), not_space)); + value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(), + value.end()); + + return value; +} + +std::string CondenseWhitespace(std::string text) { + std::string out; + out.reserve(text.size()); + + bool in_whitespace = false; + for (unsigned char ch : text) { + if (std::isspace(ch)) { + if (!in_whitespace) { + out.push_back(' '); + in_whitespace = true; + } + continue; + } + + in_whitespace = false; + out.push_back(static_cast(ch)); + } + + return Trim(std::move(out)); +} + +std::string PrepareRegionContext(std::string_view region_context, + std::size_t max_chars) { + std::string normalized = CondenseWhitespace(std::string(region_context)); + if (normalized.size() <= max_chars) { + return normalized; + } + + normalized.resize(max_chars); + const std::size_t last_space = normalized.find_last_of(' '); + if (last_space != std::string::npos && last_space > max_chars / 2) { + normalized.resize(last_space); + } + + normalized += "..."; + return normalized; +} + +std::string StripCommonPrefix(std::string line) { + line = Trim(std::move(line)); + + if (!line.empty() && (line[0] == '-' || line[0] == '*')) { + line = Trim(line.substr(1)); + } else { + std::size_t i = 0; + while (i < line.size() && + std::isdigit(static_cast(line[i]))) { + ++i; + } + if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) { + line = Trim(line.substr(i + 1)); + } + } + + auto strip_label = [&line](const std::string& label) { + if (line.size() >= label.size()) { + bool matches = true; + for (std::size_t i = 0; i < label.size(); ++i) { + if (std::tolower(static_cast(line[i])) != + std::tolower(static_cast(label[i]))) { + matches = false; + break; + } + } + if (matches) { + line = Trim(line.substr(label.size())); + } + } + }; + + strip_label("name:"); + strip_label("brewery name:"); + strip_label("description:"); + strip_label("username:"); + strip_label("bio:"); + + return Trim(std::move(line)); +} + +std::pair +ParseTwoLineResponse(const std::string& raw, const std::string& error_message) { + std::string normalized = raw; + std::replace(normalized.begin(), normalized.end(), '\r', '\n'); + + std::vector lines; + std::stringstream stream(normalized); + std::string line; + while (std::getline(stream, line)) { + line = StripCommonPrefix(std::move(line)); + if (!line.empty()) + lines.push_back(std::move(line)); + } + + std::vector filtered; + for (auto &l : lines) { + std::string low = l; + std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (!l.empty() && l.front() == '<' && low.back() == '>') + continue; + if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) + continue; + filtered.push_back(std::move(l)); + } + + if (filtered.size() < 2) + throw std::runtime_error(error_message); + + std::string first = Trim(filtered.front()); + std::string second; + for (size_t i = 1; i < filtered.size(); ++i) { + if (!second.empty()) + second += ' '; + second += filtered[i]; + } + second = Trim(std::move(second)); + + if (first.empty() || second.empty()) + throw std::runtime_error(error_message); + return {first, second}; +} + +std::string ToChatPrompt(const llama_model *model, + const std::string& user_prompt) { + const char *tmpl = llama_model_chat_template(model, nullptr); + if (tmpl == nullptr) { + return user_prompt; + } + + const llama_chat_message message{"user", user_prompt.c_str()}; + + std::vector buffer(std::max(1024, user_prompt.size() * 4)); + int32_t required = + llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), + static_cast(buffer.size())); + + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + + if (required >= static_cast(buffer.size())) { + buffer.resize(static_cast(required) + 1); + required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), + static_cast(buffer.size())); + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + } + + return std::string(buffer.data(), static_cast(required)); +} + +std::string ToChatPrompt(const llama_model *model, + const std::string& system_prompt, + const std::string& user_prompt) { + const char *tmpl = llama_model_chat_template(model, nullptr); + if (tmpl == nullptr) { + return system_prompt + "\n\n" + user_prompt; + } + + const llama_chat_message messages[2] = {{"system", system_prompt.c_str()}, + {"user", user_prompt.c_str()}}; + + std::vector buffer(std::max( + 1024, (system_prompt.size() + user_prompt.size()) * 4)); + int32_t required = + llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), + static_cast(buffer.size())); + + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + + if (required >= static_cast(buffer.size())) { + buffer.resize(static_cast(required) + 1); + required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), + static_cast(buffer.size())); + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + } + + return std::string(buffer.data(), static_cast(required)); +} + +void AppendTokenPiece(const llama_vocab *vocab, llama_token token, + std::string& output) { + std::array buffer{}; + int32_t bytes = + llama_token_to_piece(vocab, token, buffer.data(), + static_cast(buffer.size()), 0, true); + + if (bytes < 0) { + std::vector dynamic_buffer(static_cast(-bytes)); + bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), + static_cast(dynamic_buffer.size()), 0, + true); + if (bytes < 0) { + throw std::runtime_error( + "LlamaGenerator: failed to decode sampled token piece"); + } + + output.append(dynamic_buffer.data(), static_cast(bytes)); + return; + } + + output.append(buffer.data(), static_cast(bytes)); +} + +bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) { + std::size_t start = std::string::npos; + int depth = 0; + bool in_string = false; + bool escaped = false; + + for (std::size_t i = 0; i < text.size(); ++i) { + const char ch = text[i]; + + if (in_string) { + if (escaped) { + escaped = false; + } else if (ch == '\\') { + escaped = true; + } else if (ch == '"') { + in_string = false; + } + continue; + } + + if (ch == '"') { + in_string = true; + continue; + } + + if (ch == '{') { + if (depth == 0) { + start = i; + } + ++depth; + continue; + } + + if (ch == '}') { + if (depth == 0) { + continue; + } + --depth; + if (depth == 0 && start != std::string::npos) { + json_out = text.substr(start, i - start + 1); + return true; + } + } + } + + return false; +} + +std::string ValidateBreweryJson(const std::string& raw, std::string& name_out, + std::string& description_out) { + auto validate_object = [&](const boost::json::value& jv, + std::string& error_out) -> bool { + if (!jv.is_object()) { + error_out = "JSON root must be an object"; + return false; + } + + const auto& obj = jv.get_object(); + if (!obj.contains("name") || !obj.at("name").is_string()) { + error_out = "JSON field 'name' is missing or not a string"; + return false; + } + + if (!obj.contains("description") || !obj.at("description").is_string()) { + error_out = "JSON field 'description' is missing or not a string"; + return false; + } + + name_out = Trim(std::string(obj.at("name").as_string().c_str())); + description_out = + Trim(std::string(obj.at("description").as_string().c_str())); + + if (name_out.empty()) { + error_out = "JSON field 'name' must not be empty"; + return false; + } + + if (description_out.empty()) { + error_out = "JSON field 'description' must not be empty"; + return false; + } + + std::string name_lower = name_out; + std::string description_lower = description_out; + std::transform( + name_lower.begin(), name_lower.end(), name_lower.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::transform(description_lower.begin(), description_lower.end(), + description_lower.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + + if (name_lower == "string" || description_lower == "string") { + error_out = "JSON appears to be a schema placeholder, not content"; + return false; + } + + error_out.clear(); + return true; + }; + + boost::system::error_code ec; + boost::json::value jv = boost::json::parse(raw, ec); + std::string validation_error; + if (ec) { + std::string extracted; + if (!ExtractFirstJsonObject(raw, extracted)) { + return "JSON parse error: " + ec.message(); + } + + ec.clear(); + jv = boost::json::parse(extracted, ec); + if (ec) { + return "JSON parse error: " + ec.message(); + } + + if (!validate_object(jv, validation_error)) { + return validation_error; + } + + return {}; + } + + if (!validate_object(jv, validation_error)) { + return validation_error; + } + + return {}; +} + +} // namespace + +// Forward declarations for helper functions exposed to other translation units +std::string PrepareRegionContextPublic(std::string_view region_context, + std::size_t max_chars) { + return PrepareRegionContext(region_context, max_chars); +} + +std::pair +ParseTwoLineResponsePublic(const std::string& raw, + const std::string& error_message) { + return ParseTwoLineResponse(raw, error_message); +} + +std::string ToChatPromptPublic(const llama_model *model, + const std::string& user_prompt) { + return ToChatPrompt(model, user_prompt); +} + +std::string ToChatPromptPublic(const llama_model *model, + const std::string& system_prompt, + const std::string& user_prompt) { + return ToChatPrompt(model, system_prompt, user_prompt); +} + +void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, + std::string& output) { + AppendTokenPiece(vocab, token, output); +} + +std::string ValidateBreweryJsonPublic(const std::string& raw, + std::string& name_out, + std::string& description_out) { + return ValidateBreweryJson(raw, name_out, description_out); +} diff --git a/pipeline/src/data_generation/llama/infer.cpp b/pipeline/src/data_generation/llama/infer.cpp new file mode 100644 index 0000000..b3a4e13 --- /dev/null +++ b/pipeline/src/data_generation/llama/infer.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include +#include + +#include "llama.h" +#include + +#include "data_generation/llama_generator.h" +#include "data_generation/llama_generator_helpers.h" + +std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { + if (model_ == nullptr || context_ == nullptr) + throw std::runtime_error("LlamaGenerator: model not loaded"); + + const llama_vocab *vocab = llama_model_get_vocab(model_); + if (vocab == nullptr) + throw std::runtime_error("LlamaGenerator: vocab unavailable"); + + llama_memory_clear(llama_get_memory(context_), true); + + const std::string formatted_prompt = ToChatPromptPublic(model_, prompt); + + std::vector prompt_tokens(formatted_prompt.size() + 8); + int32_t token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); + + if (token_count < 0) { + prompt_tokens.resize(static_cast(-token_count)); + token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); + } + + if (token_count < 0) + throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + + const int32_t n_ctx = static_cast(llama_n_ctx(context_)); + const int32_t n_batch = static_cast(llama_n_batch(context_)); + if (n_ctx <= 1 || n_batch <= 0) { + throw std::runtime_error("LlamaGenerator: invalid context or batch size"); + } + + const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); + int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); + prompt_budget = std::max(1, prompt_budget); + + prompt_tokens.resize(static_cast(token_count)); + if (token_count > prompt_budget) { + spdlog::warn( + "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " + "to fit n_batch/n_ctx limits", + token_count, prompt_budget); + prompt_tokens.resize(static_cast(prompt_budget)); + token_count = prompt_budget; + } + + const llama_batch prompt_batch = llama_batch_get_one( + prompt_tokens.data(), static_cast(prompt_tokens.size())); + if (llama_decode(context_, prompt_batch) != 0) + throw std::runtime_error("LlamaGenerator: prompt decode failed"); + + llama_sampler_chain_params sampler_params = + llama_sampler_chain_default_params(); + using SamplerPtr = + std::unique_ptr; + SamplerPtr sampler(llama_sampler_chain_init(sampler_params), + &llama_sampler_free); + if (!sampler) + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_temp(sampling_temperature_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(sampling_seed_)); + + std::vector generated_tokens; + generated_tokens.reserve(static_cast(max_tokens)); + + for (int i = 0; i < effective_max_tokens; ++i) { + const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); + if (llama_vocab_is_eog(vocab, next)) + break; + generated_tokens.push_back(next); + llama_token token = next; + const llama_batch one_token_batch = llama_batch_get_one(&token, 1); + if (llama_decode(context_, one_token_batch) != 0) + throw std::runtime_error( + "LlamaGenerator: decode failed during generation"); + } + + std::string output; + for (const llama_token token : generated_tokens) + AppendTokenPiecePublic(vocab, token, output); + return output; +} + +std::string LlamaGenerator::Infer(const std::string& system_prompt, + const std::string& prompt, int max_tokens) { + if (model_ == nullptr || context_ == nullptr) + throw std::runtime_error("LlamaGenerator: model not loaded"); + + const llama_vocab *vocab = llama_model_get_vocab(model_); + if (vocab == nullptr) + throw std::runtime_error("LlamaGenerator: vocab unavailable"); + + llama_memory_clear(llama_get_memory(context_), true); + + const std::string formatted_prompt = + ToChatPromptPublic(model_, system_prompt, prompt); + + std::vector prompt_tokens(formatted_prompt.size() + 8); + int32_t token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); + + if (token_count < 0) { + prompt_tokens.resize(static_cast(-token_count)); + token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); + } + + if (token_count < 0) + throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + + const int32_t n_ctx = static_cast(llama_n_ctx(context_)); + const int32_t n_batch = static_cast(llama_n_batch(context_)); + if (n_ctx <= 1 || n_batch <= 0) { + throw std::runtime_error("LlamaGenerator: invalid context or batch size"); + } + + const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); + int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); + prompt_budget = std::max(1, prompt_budget); + + prompt_tokens.resize(static_cast(token_count)); + if (token_count > prompt_budget) { + spdlog::warn( + "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " + "to fit n_batch/n_ctx limits", + token_count, prompt_budget); + prompt_tokens.resize(static_cast(prompt_budget)); + token_count = prompt_budget; + } + + const llama_batch prompt_batch = llama_batch_get_one( + prompt_tokens.data(), static_cast(prompt_tokens.size())); + if (llama_decode(context_, prompt_batch) != 0) + throw std::runtime_error("LlamaGenerator: prompt decode failed"); + + llama_sampler_chain_params sampler_params = + llama_sampler_chain_default_params(); + using SamplerPtr = + std::unique_ptr; + SamplerPtr sampler(llama_sampler_chain_init(sampler_params), + &llama_sampler_free); + if (!sampler) + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_temp(sampling_temperature_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(sampling_seed_)); + + std::vector generated_tokens; + generated_tokens.reserve(static_cast(max_tokens)); + + for (int i = 0; i < effective_max_tokens; ++i) { + const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); + if (llama_vocab_is_eog(vocab, next)) + break; + generated_tokens.push_back(next); + llama_token token = next; + const llama_batch one_token_batch = llama_batch_get_one(&token, 1); + if (llama_decode(context_, one_token_batch) != 0) + throw std::runtime_error( + "LlamaGenerator: decode failed during generation"); + } + + std::string output; + for (const llama_token token : generated_tokens) + AppendTokenPiecePublic(vocab, token, output); + return output; +} diff --git a/pipeline/src/data_generation/llama/load.cpp b/pipeline/src/data_generation/llama/load.cpp new file mode 100644 index 0000000..c38808b --- /dev/null +++ b/pipeline/src/data_generation/llama/load.cpp @@ -0,0 +1,42 @@ +#include +#include + +#include "llama.h" +#include + +#include "data_generation/llama_generator.h" + +void LlamaGenerator::Load(const std::string& model_path) { + if (model_path.empty()) + throw std::runtime_error("LlamaGenerator: model path must not be empty"); + + if (context_ != nullptr) { + llama_free(context_); + context_ = nullptr; + } + if (model_ != nullptr) { + llama_model_free(model_); + model_ = nullptr; + } + + llama_backend_init(); + + llama_model_params model_params = llama_model_default_params(); + model_ = llama_model_load_from_file(model_path.c_str(), model_params); + if (model_ == nullptr) { + throw std::runtime_error( + "LlamaGenerator: failed to load model from path: " + model_path); + } + + llama_context_params context_params = llama_context_default_params(); + context_params.n_ctx = 2048; + + context_ = llama_init_from_model(model_, context_params); + if (context_ == nullptr) { + llama_model_free(model_); + model_ = nullptr; + throw std::runtime_error("LlamaGenerator: failed to create context"); + } + + spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); +} diff --git a/pipeline/src/data_generation/llama/set_sampling_options.cpp b/pipeline/src/data_generation/llama/set_sampling_options.cpp new file mode 100644 index 0000000..3898fca --- /dev/null +++ b/pipeline/src/data_generation/llama/set_sampling_options.cpp @@ -0,0 +1,26 @@ +#include + +#include "llama.h" + +#include "data_generation/llama_generator.h" + +void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, + int seed) { + if (temperature < 0.0f) { + throw std::runtime_error( + "LlamaGenerator: sampling temperature must be >= 0"); + } + if (!(top_p > 0.0f && top_p <= 1.0f)) { + throw std::runtime_error( + "LlamaGenerator: sampling top-p must be in (0, 1]"); + } + if (seed < -1) { + throw std::runtime_error( + "LlamaGenerator: seed must be >= 0, or -1 for random"); + } + + sampling_temperature_ = temperature; + sampling_top_p_ = top_p; + sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED) + : static_cast(seed); +} diff --git a/pipeline/src/data_generation/llama_generator.cpp b/pipeline/src/data_generation/llama_generator.cpp deleted file mode 100644 index d35c65b..0000000 --- a/pipeline/src/data_generation/llama_generator.cpp +++ /dev/null @@ -1,734 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llama.h" -#include -#include - -#include "data_generation/llama_generator.h" - -namespace { - -std::string trim(std::string value) { - auto notSpace = [](unsigned char ch) { return !std::isspace(ch); }; - - value.erase(value.begin(), - std::find_if(value.begin(), value.end(), notSpace)); - value.erase(std::find_if(value.rbegin(), value.rend(), notSpace).base(), - value.end()); - - return value; -} - -std::string CondenseWhitespace(std::string text) { - std::string out; - out.reserve(text.size()); - - bool inWhitespace = false; - for (unsigned char ch : text) { - if (std::isspace(ch)) { - if (!inWhitespace) { - out.push_back(' '); - inWhitespace = true; - } - continue; - } - - inWhitespace = false; - out.push_back(static_cast(ch)); - } - - return trim(std::move(out)); -} - -std::string PrepareRegionContext(std::string_view regionContext, - std::size_t maxChars = 700) { - std::string normalized = CondenseWhitespace(std::string(regionContext)); - if (normalized.size() <= maxChars) { - return normalized; - } - - normalized.resize(maxChars); - const std::size_t lastSpace = normalized.find_last_of(' '); - if (lastSpace != std::string::npos && lastSpace > maxChars / 2) { - normalized.resize(lastSpace); - } - - normalized += "..."; - return normalized; -} - -std::string stripCommonPrefix(std::string line) { - line = trim(std::move(line)); - - if (!line.empty() && (line[0] == '-' || line[0] == '*')) { - line = trim(line.substr(1)); - } else { - std::size_t i = 0; - while (i < line.size() && - std::isdigit(static_cast(line[i]))) { - ++i; - } - if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) { - line = trim(line.substr(i + 1)); - } - } - - auto stripLabel = [&line](const std::string &label) { - if (line.size() >= label.size()) { - bool matches = true; - for (std::size_t i = 0; i < label.size(); ++i) { - if (std::tolower(static_cast(line[i])) != - std::tolower(static_cast(label[i]))) { - matches = false; - break; - } - } - if (matches) { - line = trim(line.substr(label.size())); - } - } - }; - - stripLabel("name:"); - stripLabel("brewery name:"); - stripLabel("description:"); - stripLabel("username:"); - stripLabel("bio:"); - - return trim(std::move(line)); -} - -std::pair -parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) { - std::string normalized = raw; - std::replace(normalized.begin(), normalized.end(), '\r', '\n'); - - std::vector lines; - std::stringstream stream(normalized); - std::string line; - while (std::getline(stream, line)) { - line = stripCommonPrefix(std::move(line)); - if (!line.empty()) - lines.push_back(std::move(line)); - } - - std::vector filtered; - for (auto &l : lines) { - std::string low = l; - std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { - return static_cast(std::tolower(c)); - }); - if (!l.empty() && l.front() == '<' && low.back() == '>') - continue; - if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) - continue; - filtered.push_back(std::move(l)); - } - - if (filtered.size() < 2) - throw std::runtime_error(errorMessage); - - std::string first = trim(filtered.front()); - std::string second; - for (size_t i = 1; i < filtered.size(); ++i) { - if (!second.empty()) - second += ' '; - second += filtered[i]; - } - second = trim(std::move(second)); - - if (first.empty() || second.empty()) - throw std::runtime_error(errorMessage); - return {first, second}; -} - -std::string toChatPrompt(const llama_model *model, - const std::string &userPrompt) { - const char *tmpl = llama_model_chat_template(model, nullptr); - if (tmpl == nullptr) { - return userPrompt; - } - - const llama_chat_message message{"user", userPrompt.c_str()}; - - std::vector buffer(std::max(1024, userPrompt.size() * 4)); - int32_t required = - llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), - static_cast(buffer.size())); - - if (required < 0) { - throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - - if (required >= static_cast(buffer.size())) { - buffer.resize(static_cast(required) + 1); - required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), - static_cast(buffer.size())); - if (required < 0) { - throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - } - - return std::string(buffer.data(), static_cast(required)); -} - -std::string toChatPrompt(const llama_model *model, - const std::string &system_prompt, - const std::string &userPrompt) { - const char *tmpl = llama_model_chat_template(model, nullptr); - if (tmpl == nullptr) { - return system_prompt + "\n\n" + userPrompt; - } - - const llama_chat_message messages[2] = {{"system", system_prompt.c_str()}, - {"user", userPrompt.c_str()}}; - - std::vector buffer(std::max( - 1024, (systemPrompt.size() + userPrompt.size()) * 4)); - int32_t required = - llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), - static_cast(buffer.size())); - - if (required < 0) { - throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - - if (required >= static_cast(buffer.size())) { - buffer.resize(static_cast(required) + 1); - required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), - static_cast(buffer.size())); - if (required < 0) { - throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - } - - return std::string(buffer.data(), static_cast(required)); -} - -void appendTokenPiece(const llama_vocab *vocab, llama_token token, - std::string &output) { - std::array buffer{}; - int32_t bytes = - llama_token_to_piece(vocab, token, buffer.data(), - static_cast(buffer.size()), 0, true); - - if (bytes < 0) { - std::vector dynamicBuffer(static_cast(-bytes)); - bytes = llama_token_to_piece(vocab, token, dynamicBuffer.data(), - static_cast(dynamicBuffer.size()), 0, - true); - if (bytes < 0) { - throw std::runtime_error( - "LlamaGenerator: failed to decode sampled token piece"); - } - - output.append(dynamicBuffer.data(), static_cast(bytes)); - return; - } - - output.append(buffer.data(), static_cast(bytes)); -} - -bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) { - std::size_t start = std::string::npos; - int depth = 0; - bool inString = false; - bool escaped = false; - - for (std::size_t i = 0; i < text.size(); ++i) { - const char ch = text[i]; - - if (inString) { - if (escaped) { - escaped = false; - } else if (ch == '\\') { - escaped = true; - } else if (ch == '"') { - inString = false; - } - continue; - } - - if (ch == '"') { - inString = true; - continue; - } - - if (ch == '{') { - if (depth == 0) { - start = i; - } - ++depth; - continue; - } - - if (ch == '}') { - if (depth == 0) { - continue; - } - --depth; - if (depth == 0 && start != std::string::npos) { - jsonOut = text.substr(start, i - start + 1); - return true; - } - } - } - - return false; -} - -std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut, - std::string &descriptionOut) { - auto validateObject = [&](const boost::json::value &jv, - std::string &errorOut) -> bool { - if (!jv.is_object()) { - errorOut = "JSON root must be an object"; - return false; - } - - const auto &obj = jv.get_object(); - if (!obj.contains("name") || !obj.at("name").is_string()) { - errorOut = "JSON field 'name' is missing or not a string"; - return false; - } - - if (!obj.contains("description") || !obj.at("description").is_string()) { - errorOut = "JSON field 'description' is missing or not a string"; - return false; - } - - nameOut = trim(std::string(obj.at("name").as_string().c_str())); - descriptionOut = - trim(std::string(obj.at("description").as_string().c_str())); - - if (nameOut.empty()) { - errorOut = "JSON field 'name' must not be empty"; - return false; - } - - if (descriptionOut.empty()) { - errorOut = "JSON field 'description' must not be empty"; - return false; - } - - std::string nameLower = nameOut; - std::string descriptionLower = descriptionOut; - std::transform( - nameLower.begin(), nameLower.end(), nameLower.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - std::transform(descriptionLower.begin(), descriptionLower.end(), - descriptionLower.begin(), [](unsigned char c) { - return static_cast(std::tolower(c)); - }); - - if (nameLower == "string" || descriptionLower == "string") { - errorOut = "JSON appears to be a schema placeholder, not content"; - return false; - } - - errorOut.clear(); - return true; - }; - - boost::system::error_code ec; - boost::json::value jv = boost::json::parse(raw, ec); - std::string validationError; - if (ec) { - std::string extracted; - if (!extractFirstJsonObject(raw, extracted)) { - return "JSON parse error: " + ec.message(); - } - - ec.clear(); - jv = boost::json::parse(extracted, ec); - if (ec) { - return "JSON parse error: " + ec.message(); - } - - if (!validateObject(jv, validationError)) { - return validationError; - } - - return {}; - } - - if (!validateObject(jv, validationError)) { - return validationError; - } - - return {}; -} -} // namespace - -LlamaGenerator::~LlamaGenerator() { - if (context_ != nullptr) { - llama_free(context_); - context_ = nullptr; - } - - if (model_ != nullptr) { - llama_model_free(model_); - model_ = nullptr; - } - - llama_backend_free(); -} - -void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, - int seed) { - if (temperature < 0.0f) { - throw std::runtime_error( - "LlamaGenerator: sampling temperature must be >= 0"); - } - if (!(top_p > 0.0f && top_p <= 1.0f)) { - throw std::runtime_error( - "LlamaGenerator: sampling top-p must be in (0, 1]"); - } - if (seed < -1) { - throw std::runtime_error( - "LlamaGenerator: seed must be >= 0, or -1 for random"); - } - - sampling_temperature_ = temperature; - sampling_top_p_ = top_p; - sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED) - : static_cast(seed); -} - -void LlamaGenerator::Load(const std::string &model_path) { - if (model_path.empty()) - throw std::runtime_error("LlamaGenerator: model path must not be empty"); - - if (context_ != nullptr) { - llama_free(context_); - context_ = nullptr; - } - if (model_ != nullptr) { - llama_model_free(model_); - model_ = nullptr; - } - - llama_backend_init(); - - llama_model_params model_params = llama_model_default_params(); - model_ = llama_model_load_from_file(model_path.c_str(), model_params); - if (model_ == nullptr) { - throw std::runtime_error( - "LlamaGenerator: failed to load model from path: " + model_path); - } - - llama_context_params context_params = llama_context_default_params(); - context_params.n_ctx = 2048; - - context_ = llama_init_from_model(model_, context_params); - if (context_ == nullptr) { - llama_model_free(model_); - model_ = nullptr; - throw std::runtime_error("LlamaGenerator: failed to create context"); - } - - spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); -} - -std::string LlamaGenerator::Infer(const std::string &prompt, int max_tokens) { - if (model_ == nullptr || context_ == nullptr) - throw std::runtime_error("LlamaGenerator: model not loaded"); - - const llama_vocab *vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) - throw std::runtime_error("LlamaGenerator: vocab unavailable"); - - llama_memory_clear(llama_get_memory(context_), true); - - const std::string formatted_prompt = toChatPrompt(model_, prompt); - - std::vector promptTokens(formatted_prompt.size() + 8); - int32_t tokenCount = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), promptTokens.data(), - static_cast(promptTokens.size()), true, true); - - if (tokenCount < 0) { - promptTokens.resize(static_cast(-tokenCount)); - tokenCount = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), promptTokens.data(), - static_cast(promptTokens.size()), true, true); - } - - if (tokenCount < 0) - throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); - - const int32_t nCtx = static_cast(llama_n_ctx(context_)); - const int32_t nBatch = static_cast(llama_n_batch(context_)); - if (nCtx <= 1 || nBatch <= 0) { - throw std::runtime_error("LlamaGenerator: invalid context or batch size"); - } - - const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1)); - const int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens); - prompt_budget = std::max(1, prompt_budget); - - promptTokens.resize(static_cast(tokenCount)); - if (tokenCount > prompt_budget) { - spdlog::warn( - "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " - "to fit n_batch/n_ctx limits", - tokenCount, prompt_budget); - promptTokens.resize(static_cast(prompt_budget)); - tokenCount = prompt_budget; - } - - const llama_batch promptBatch = llama_batch_get_one( - promptTokens.data(), static_cast(promptTokens.size())); - if (llama_decode(context_, promptBatch) != 0) - throw std::runtime_error("LlamaGenerator: prompt decode failed"); - - llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - using SamplerPtr = - std::unique_ptr; - SamplerPtr sampler(llama_sampler_chain_init(sampler_params), - &llama_sampler_free); - if (!sampler) - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_temp(sampling_temperature_)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_top_p(sampling_top_p_, 1)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_dist(sampling_seed_)); - - std::vector generated_tokens; - generated_tokens.reserve(static_cast(max_tokens)); - - for (int i = 0; i < effective_max_tokens; ++i) { - const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) - break; - generated_tokens.push_back(next); - llama_token token = next; - const llama_batch one_token_batch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, one_token_batch) != 0) - throw std::runtime_error( - "LlamaGenerator: decode failed during generation"); - } - - std::string output; - for (const llama_token token : generated_tokens) - appendTokenPiece(vocab, token, output); - return output; -} - -std::string LlamaGenerator::Infer(const std::string &system_prompt, - const std::string &prompt, int max_tokens) { - if (model_ == nullptr || context_ == nullptr) - throw std::runtime_error("LlamaGenerator: model not loaded"); - - const llama_vocab *vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) - throw std::runtime_error("LlamaGenerator: vocab unavailable"); - - llama_memory_clear(llama_get_memory(context_), true); - - const std::string formatted_prompt = - toChatPrompt(model_, system_prompt, prompt); - - std::vector promptTokens(formatted_prompt.size() + 8); - int32_t tokenCount = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), promptTokens.data(), - static_cast(promptTokens.size()), true, true); - - if (tokenCount < 0) { - promptTokens.resize(static_cast(-tokenCount)); - tokenCount = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), promptTokens.data(), - static_cast(promptTokens.size()), true, true); - } - - if (tokenCount < 0) - throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); - - const int32_t nCtx = static_cast(llama_n_ctx(context_)); - const int32_t nBatch = static_cast(llama_n_batch(context_)); - if (nCtx <= 1 || nBatch <= 0) { - throw std::runtime_error("LlamaGenerator: invalid context or batch size"); - } - - const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1)); - int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens); - prompt_budget = std::max(1, prompt_budget); - - promptTokens.resize(static_cast(tokenCount)); - if (tokenCount > prompt_budget) { - spdlog::warn( - "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " - "to fit n_batch/n_ctx limits", - tokenCount, prompt_budget); - promptTokens.resize(static_cast(prompt_budget)); - tokenCount = prompt_budget; - } - - const llama_batch promptBatch = llama_batch_get_one( - promptTokens.data(), static_cast(promptTokens.size())); - if (llama_decode(context_, promptBatch) != 0) - throw std::runtime_error("LlamaGenerator: prompt decode failed"); - - llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - using SamplerPtr = - std::unique_ptr; - SamplerPtr sampler(llama_sampler_chain_init(sampler_params), - &llama_sampler_free); - if (!sampler) - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_temp(sampling_temperature_)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_top_p(sampling_top_p_, 1)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_dist(sampling_seed_)); - - std::vector generated_tokens; - generated_tokens.reserve(static_cast(max_tokens)); - - for (int i = 0; i < effective_max_tokens; ++i) { - const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) - break; - generated_tokens.push_back(next); - llama_token token = next; - const llama_batch one_token_batch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, one_token_batch) != 0) - throw std::runtime_error( - "LlamaGenerator: decode failed during generation"); - } - - std::string output; - for (const llama_token token : generated_tokens) - appendTokenPiece(vocab, token, output); - return output; -} - -BreweryResult -LlamaGenerator::GenerateBrewery(const std::string &city_name, - const std::string &country_name, - const std::string ®ion_context) { - const std::string safe_region_context = PrepareRegionContext(region_context); - - const std::string system_prompt = - "You are a copywriter for a craft beer travel guide. " - "Your writing is vivid, specific to place, and avoids generic beer " - "cliches. " - "You must output ONLY valid JSON. " - "The JSON schema must be exactly: {\"name\": \"string\", " - "\"description\": \"string\"}. " - "Do not include markdown formatting or backticks."; - - std::string prompt = - "Write a brewery name and place-specific description for a craft " - "brewery in " + - city_name + - (country_name.empty() ? std::string("") - : std::string(", ") + country_name) + - (safe_region_context.empty() - ? std::string(".") - : std::string(". Regional context: ") + safe_region_context); - - const int maxAttempts = 3; - std::string raw; - std::string lastError; - for (int attempt = 0; attempt < maxAttempts; ++attempt) { - raw = Infer(system_prompt, prompt, 384); - spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, - raw); - - std::string name; - std::string description; - const std::string validationError = - ValidateBreweryJson(raw, name, description); - if (validationError.empty()) { - return {std::move(name), std::move(description)}; - } - - lastError = validationError; - spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", - attempt + 1, validationError); - - prompt = "Your previous response was invalid. Error: " + validationError + - "\nReturn ONLY valid JSON with this exact schema: " - "{\"name\": \"string\", \"description\": \"string\"}." - "\nDo not include markdown, comments, or extra keys." - "\n\nLocation: " + - city_name + - (country_name.empty() ? std::string("") - : std::string(", ") + country_name) + - (safe_region_context.empty() - ? std::string("") - : std::string("\nRegional context: ") + safe_region_context); - } - - spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " - "{}", - maxAttempts, lastError.empty() ? raw : lastError); - throw std::runtime_error("LlamaGenerator: malformed brewery response"); -} - -UserResult LlamaGenerator::GenerateUser(const std::string &locale) { - const std::string system_prompt = - "You generate plausible social media profiles for craft beer " - "enthusiasts. " - "Respond with exactly two lines: " - "the first line is a username (lowercase, no spaces, 8-20 characters), " - "the second line is a one-sentence bio (20-40 words). " - "The profile should feel consistent with the locale. " - "No preamble, no labels."; - - std::string prompt = - "Generate a craft beer enthusiast profile. Locale: " + locale; - - const int maxAttempts = 3; - std::string raw; - for (int attempt = 0; attempt < maxAttempts; ++attempt) { - raw = Infer(system_prompt, prompt, 128); - spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}", - attempt + 1, raw); - - try { - auto [username, bio] = - parseTwoLineResponse(raw, "LlamaGenerator: malformed user response"); - - username.erase( - std::remove_if(username.begin(), username.end(), - [](unsigned char ch) { return std::isspace(ch); }), - username.end()); - - if (username.empty() || bio.empty()) { - throw std::runtime_error("LlamaGenerator: malformed user response"); - } - - if (bio.size() > 200) - bio = bio.substr(0, 200); - - return {username, bio}; - } catch (const std::exception &e) { - spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}", - attempt + 1, e.what()); - } - } - - spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}", - maxAttempts, raw); - throw std::runtime_error("LlamaGenerator: malformed user response"); -} diff --git a/pipeline/src/wikipedia/wikipedia_service.cpp b/pipeline/src/wikipedia/wikipedia_service.cpp index e1bd82d..c42bf27 100644 --- a/pipeline/src/wikipedia/wikipedia_service.cpp +++ b/pipeline/src/wikipedia/wikipedia_service.cpp @@ -2,7 +2,7 @@ #include #include -WikipediaService::WikipediaService(std::shared_ptr client) +WikipediaService::WikipediaService(std::shared_ptr client) : client_(std::move(client)) {} std::string WikipediaService::FetchExtract(std::string_view query) {