diff --git a/pipeline/.clang-format b/pipeline/.clang-format index f608170..4ecd342 100644 --- a/pipeline/.clang-format +++ b/pipeline/.clang-format @@ -1,10 +1,5 @@ --- BasedOnStyle: Google -Standard: c++23 -ColumnLimit: 100 -IndentWidth: 2 -DerivePointerAlignment: false -PointerAlignment: Left -SortIncludes: true -IncludeBlocks: Preserve +ColumnLimit: 80 +IndentWidth: 3 ... diff --git a/pipeline/CMakeLists.txt b/pipeline/CMakeLists.txt index 3e2ec6d..0d7e0a1 100644 --- a/pipeline/CMakeLists.txt +++ b/pipeline/CMakeLists.txt @@ -90,7 +90,11 @@ set(PIPELINE_SOURCES src/data_generation/llama/generate_brewery.cpp src/data_generation/llama/generate_user.cpp src/data_generation/llama/helpers.cpp - src/data_generation/mock_generator.cpp + src/data_generation/mock/data.cpp + src/data_generation/mock/deterministic_hash.cpp + src/data_generation/mock/load.cpp + src/data_generation/mock/generate_brewery.cpp + src/data_generation/mock/generate_user.cpp src/json_handling/stream_parser.cpp src/wikipedia/wikipedia_service.cpp src/main.cpp diff --git a/pipeline/includes/data_generation/data_downloader.h b/pipeline/includes/data_generation/data_downloader.h index ded7581..cf2de92 100644 --- a/pipeline/includes/data_generation/data_downloader.h +++ b/pipeline/includes/data_generation/data_downloader.h @@ -9,22 +9,23 @@ /// @brief Downloads and caches source geography JSON payloads. class DataDownloader { -public: - /// @brief Initializes global curl state used by this downloader. - explicit DataDownloader(std::shared_ptr web_client); + public: + /// @brief Initializes global curl state used by this downloader. + explicit DataDownloader(std::shared_ptr web_client); - /// @brief Cleans up global curl state. - ~DataDownloader(); + /// @brief Cleans up global curl state. + ~DataDownloader(); - /// @brief Returns a local JSON path, downloading it when cache is missing. - std::string DownloadCountriesDatabase( - const std::string &cache_path, - const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export - ); + /// @brief Returns a local JSON path, downloading it when cache is missing. + std::string DownloadCountriesDatabase( + const std::string& cache_path, + const std::string& commit = + "c5eb7772" // Stable commit: 2026-03-28 export + ); -private: - static bool FileExists(const std::string &file_path); - std::shared_ptr web_client_; + private: + static bool FileExists(const std::string& file_path); + std::shared_ptr web_client_; }; #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_ diff --git a/pipeline/includes/data_generation/data_generator.h b/pipeline/includes/data_generation/data_generator.h index 6f5a315..18a2204 100644 --- a/pipeline/includes/data_generation/data_generator.h +++ b/pipeline/includes/data_generation/data_generator.h @@ -4,26 +4,26 @@ #include struct BreweryResult { - std::string name; - std::string description; + std::string name; + std::string description; }; struct UserResult { - std::string username; - std::string bio; + std::string username; + std::string bio; }; class DataGenerator { -public: - virtual ~DataGenerator() = default; + public: + virtual ~DataGenerator() = default; - virtual void Load(const std::string &model_path) = 0; + virtual void Load(const std::string& model_path) = 0; - virtual BreweryResult GenerateBrewery(const std::string &city_name, - const std::string &country_name, - const std::string ®ion_context) = 0; + virtual BreweryResult GenerateBrewery(const std::string& city_name, + const std::string& country_name, + const std::string& region_context) = 0; - virtual UserResult GenerateUser(const std::string &locale) = 0; + virtual UserResult GenerateUser(const std::string& locale) = 0; }; #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_ diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index 4548205..31d29b1 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -10,32 +10,32 @@ struct llama_model; struct llama_context; class LlamaGenerator final : public DataGenerator { -public: - LlamaGenerator() = default; - ~LlamaGenerator() override; + public: + LlamaGenerator() = default; + ~LlamaGenerator() override; - void SetSamplingOptions(float temperature, float top_p, int seed = -1); + void SetSamplingOptions(float temperature, float top_p, int seed = -1); - void Load(const std::string &model_path) override; - BreweryResult GenerateBrewery(const std::string &city_name, - const std::string &country_name, - const std::string ®ion_context) override; - UserResult GenerateUser(const std::string &locale) override; + void Load(const std::string& model_path) override; + BreweryResult GenerateBrewery(const std::string& city_name, + const std::string& country_name, + const std::string& region_context) override; + UserResult GenerateUser(const std::string& locale) override; -private: - std::string Infer(const std::string &prompt, int max_tokens = 10000); - // Overload that allows passing a system message separately so chat-capable - // models receive a proper system role instead of having the system text - // concatenated into the user prompt (helps avoid revealing internal - // reasoning or instructions in model output). - std::string Infer(const std::string &system_prompt, const std::string &prompt, - int max_tokens = 10000); + private: + std::string Infer(const std::string& prompt, int max_tokens = 10000); + // Overload that allows passing a system message separately so chat-capable + // models receive a proper system role instead of having the system text + // concatenated into the user prompt (helps avoid revealing internal + // reasoning or instructions in model output). + std::string Infer(const std::string& system_prompt, + const std::string& prompt, int max_tokens = 10000); - llama_model *model_ = nullptr; - llama_context *context_ = nullptr; - float sampling_temperature_ = 0.8f; - float sampling_top_p_ = 0.92f; - uint32_t sampling_seed_ = 0xFFFFFFFFu; + llama_model* model_ = nullptr; + llama_context* context_ = nullptr; + float sampling_temperature_ = 0.8f; + float sampling_top_p_ = 0.92f; + uint32_t sampling_seed_ = 0xFFFFFFFFu; }; #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_ diff --git a/pipeline/includes/data_generation/llama_generator_helpers.h b/pipeline/includes/data_generation/llama_generator_helpers.h index 11331de..5db0e48 100644 --- a/pipeline/includes/data_generation/llama_generator_helpers.h +++ b/pipeline/includes/data_generation/llama_generator_helpers.h @@ -12,18 +12,17 @@ typedef int llama_token; std::string PrepareRegionContextPublic(std::string_view region_context, std::size_t max_chars = 700); -std::pair -ParseTwoLineResponsePublic(const std::string& raw, - const std::string& error_message); +std::pair ParseTwoLineResponsePublic( + const std::string& raw, const std::string& error_message); -std::string ToChatPromptPublic(const llama_model *model, +std::string ToChatPromptPublic(const llama_model* model, const std::string& user_prompt); -std::string ToChatPromptPublic(const llama_model *model, +std::string ToChatPromptPublic(const llama_model* model, const std::string& system_prompt, const std::string& user_prompt); -void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, +void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, std::string& output); std::string ValidateBreweryJsonPublic(const std::string& raw, diff --git a/pipeline/includes/data_generation/mock_generator.h b/pipeline/includes/data_generation/mock_generator.h index efe4ad0..69cf1bd 100644 --- a/pipeline/includes/data_generation/mock_generator.h +++ b/pipeline/includes/data_generation/mock_generator.h @@ -1,27 +1,28 @@ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ -#include "data_generation/data_generator.h" #include #include +#include "data_generation/data_generator.h" + class MockGenerator final : public DataGenerator { -public: - void Load(const std::string &model_path) override; - BreweryResult GenerateBrewery(const std::string &city_name, - const std::string &country_name, - const std::string ®ion_context) override; - UserResult GenerateUser(const std::string &locale) override; + public: + void Load(const std::string& model_path) override; + BreweryResult GenerateBrewery(const std::string& city_name, + const std::string& country_name, + const std::string& region_context) override; + UserResult GenerateUser(const std::string& locale) override; -private: - static std::size_t DeterministicHash(const std::string &a, - const std::string &b); + private: + static std::size_t DeterministicHash(const std::string& a, + const std::string& b); - static const std::vector kBreweryAdjectives; - static const std::vector kBreweryNouns; - static const std::vector kBreweryDescriptions; - static const std::vector kUsernames; - static const std::vector kBios; + static const std::vector kBreweryAdjectives; + static const std::vector kBreweryNouns; + static const std::vector kBreweryDescriptions; + static const std::vector kUsernames; + static const std::vector kBios; }; #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ diff --git a/pipeline/includes/database/database.h b/pipeline/includes/database/database.h index 97d91e7..03307fe 100644 --- a/pipeline/includes/database/database.h +++ b/pipeline/includes/database/database.h @@ -1,83 +1,84 @@ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ -#include #include + +#include #include #include struct Country { - /// @brief Country identifier from the source dataset. - int id; - /// @brief Country display name. - std::string name; - /// @brief ISO 3166-1 alpha-2 code. - std::string iso2; - /// @brief ISO 3166-1 alpha-3 code. - std::string iso3; + /// @brief Country identifier from the source dataset. + int id; + /// @brief Country display name. + std::string name; + /// @brief ISO 3166-1 alpha-2 code. + std::string iso2; + /// @brief ISO 3166-1 alpha-3 code. + std::string iso3; }; struct State { - /// @brief State or province identifier from the source dataset. - int id; - /// @brief State or province display name. - std::string name; - /// @brief State or province short code. - std::string iso2; - /// @brief Parent country identifier. - int country_id; + /// @brief State or province identifier from the source dataset. + int id; + /// @brief State or province display name. + std::string name; + /// @brief State or province short code. + std::string iso2; + /// @brief Parent country identifier. + int country_id; }; struct City { - /// @brief City identifier from the source dataset. - int id; - /// @brief City display name. - std::string name; - /// @brief Parent country identifier. - int country_id; + /// @brief City identifier from the source dataset. + int id; + /// @brief City display name. + std::string name; + /// @brief Parent country identifier. + int country_id; }; /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. class SqliteDatabase { -private: - sqlite3 *db_ = nullptr; - std::mutex db_mutex_; + private: + sqlite3* db_ = nullptr; + std::mutex db_mutex_; - void InitializeSchema(); + void InitializeSchema(); -public: - /// @brief Closes the SQLite connection if initialized. - ~SqliteDatabase(); + public: + /// @brief Closes the SQLite connection if initialized. + ~SqliteDatabase(); - /// @brief Opens the SQLite database at db_path and creates schema objects. - void Initialize(const std::string &db_path = ":memory:"); + /// @brief Opens the SQLite database at db_path and creates schema objects. + void Initialize(const std::string& db_path = ":memory:"); - /// @brief Starts a database transaction for batched writes. - void BeginTransaction(); + /// @brief Starts a database transaction for batched writes. + void BeginTransaction(); - /// @brief Commits the active database transaction. - void CommitTransaction(); + /// @brief Commits the active database transaction. + void CommitTransaction(); - /// @brief Inserts a country row. - void InsertCountry(int id, const std::string &name, const std::string &iso2, - const std::string &iso3); + /// @brief Inserts a country row. + void InsertCountry(int id, const std::string& name, const std::string& iso2, + const std::string& iso3); - /// @brief Inserts a state row linked to a country. - void InsertState(int id, int country_id, const std::string &name, - const std::string &iso2); + /// @brief Inserts a state row linked to a country. + void InsertState(int id, int country_id, const std::string& name, + const std::string& iso2); - /// @brief Inserts a city row linked to state and country. - void InsertCity(int id, int state_id, int country_id, const std::string &name, - double latitude, double longitude); + /// @brief Inserts a city row linked to state and country. + void InsertCity(int id, int state_id, int country_id, + const std::string& name, double latitude, double longitude); - /// @brief Returns city records including parent country id. - std::vector QueryCities(); + /// @brief Returns city records including parent country id. + std::vector QueryCities(); - /// @brief Returns countries with optional row limit. - std::vector QueryCountries(int limit = 0); + /// @brief Returns countries with optional row limit. + std::vector QueryCountries(int limit = 0); - /// @brief Returns states with optional row limit. - std::vector QueryStates(int limit = 0); + /// @brief Returns states with optional row limit. + std::vector QueryStates(int limit = 0); }; #endif // BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ diff --git a/pipeline/includes/json_handling/json_loader.h b/pipeline/includes/json_handling/json_loader.h index 9f5d2e0..d6fca00 100644 --- a/pipeline/includes/json_handling/json_loader.h +++ b/pipeline/includes/json_handling/json_loader.h @@ -1,15 +1,17 @@ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ +#include + #include "database/database.h" #include "json_handling/stream_parser.h" -#include /// @brief Loads world-city JSON data into SQLite through streaming parsing. class JsonLoader { -public: - /// @brief Parses a JSON file and writes country/state/city rows into db. - static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); + public: + /// @brief Parses a JSON file and writes country/state/city rows into db. + static void LoadWorldCities(const std::string& json_path, + SqliteDatabase& db); }; #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ diff --git a/pipeline/includes/json_handling/stream_parser.h b/pipeline/includes/json_handling/stream_parser.h index f31e984..f712702 100644 --- a/pipeline/includes/json_handling/stream_parser.h +++ b/pipeline/includes/json_handling/stream_parser.h @@ -1,51 +1,52 @@ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ -#include "database/database.h" #include #include +#include "database/database.h" + // Forward declaration to avoid circular dependency class SqliteDatabase; /// @brief In-memory representation of one parsed city entry. struct CityRecord { - int id; - int state_id; - int country_id; - std::string name; - double latitude; - double longitude; + int id; + int state_id; + int country_id; + std::string name; + double latitude; + double longitude; }; /// @brief Streaming SAX parser that emits city records during traversal. class StreamingJsonParser { -public: - /// @brief Parses file_path and invokes callbacks for city rows and progress. - static void Parse(const std::string &file_path, SqliteDatabase &db, - std::function on_city, - std::function on_progress = nullptr); + public: + /// @brief Parses file_path and invokes callbacks for city rows and progress. + static void Parse(const std::string& file_path, SqliteDatabase& db, + std::function on_city, + std::function on_progress = nullptr); -private: - /// @brief Mutable SAX handler state while traversing nested JSON arrays. - struct ParseState { - int current_country_id = 0; - int current_state_id = 0; + private: + /// @brief Mutable SAX handler state while traversing nested JSON arrays. + struct ParseState { + int current_country_id = 0; + int current_state_id = 0; - CityRecord current_city = {}; - bool building_city = false; - std::string current_key; + CityRecord current_city = {}; + bool building_city = false; + std::string current_key; - int array_depth = 0; - int object_depth = 0; - bool in_countries_array = false; - bool in_states_array = false; - bool in_cities_array = false; + int array_depth = 0; + int object_depth = 0; + bool in_countries_array = false; + bool in_states_array = false; + bool in_cities_array = false; - std::function on_city; - std::function on_progress; - size_t bytes_processed = 0; - }; + std::function on_city; + std::function on_progress; + size_t bytes_processed = 0; + }; }; #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ diff --git a/pipeline/includes/web_client/curl_web_client.h b/pipeline/includes/web_client/curl_web_client.h index 21fc20a..ce57000 100644 --- a/pipeline/includes/web_client/curl_web_client.h +++ b/pipeline/includes/web_client/curl_web_client.h @@ -1,29 +1,30 @@ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ -#include "web_client/web_client.h" #include +#include "web_client/web_client.h" + // RAII for curl_global_init/cleanup. // An instance of this class should be created in main() before any curl // operations and exist for the lifetime of the application. class CurlGlobalState { -public: - CurlGlobalState(); - ~CurlGlobalState(); - CurlGlobalState(const CurlGlobalState &) = delete; - CurlGlobalState &operator=(const CurlGlobalState &) = delete; + public: + CurlGlobalState(); + ~CurlGlobalState(); + CurlGlobalState(const CurlGlobalState&) = delete; + CurlGlobalState& operator=(const CurlGlobalState&) = delete; }; class CURLWebClient : public WebClient { -public: - CURLWebClient(); - ~CURLWebClient() override; + public: + CURLWebClient(); + ~CURLWebClient() override; - void DownloadToFile(const std::string &url, - const std::string &file_path) override; - std::string Get(const std::string &url) override; - std::string UrlEncode(const std::string &value) override; + void DownloadToFile(const std::string& url, + const std::string& file_path) override; + std::string Get(const std::string& url) override; + std::string UrlEncode(const std::string& value) override; }; #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ diff --git a/pipeline/includes/web_client/web_client.h b/pipeline/includes/web_client/web_client.h index 92344d4..fbd2ca7 100644 --- a/pipeline/includes/web_client/web_client.h +++ b/pipeline/includes/web_client/web_client.h @@ -4,19 +4,19 @@ #include class WebClient { -public: - virtual ~WebClient() = default; + public: + virtual ~WebClient() = default; - // Downloads content from a URL to a file. Throws on error. - virtual void DownloadToFile(const std::string &url, - const std::string &file_path) = 0; + // Downloads content from a URL to a file. Throws on error. + virtual void DownloadToFile(const std::string& url, + const std::string& file_path) = 0; - // Performs a GET request and returns the response body as a string. Throws on - // error. - virtual std::string Get(const std::string &url) = 0; + // Performs a GET request and returns the response body as a string. Throws + // on error. + virtual std::string Get(const std::string& url) = 0; - // URL-encodes a string. - virtual std::string UrlEncode(const std::string &value) = 0; + // URL-encodes a string. + virtual std::string UrlEncode(const std::string& value) = 0; }; #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_ diff --git a/pipeline/includes/wikipedia/wikipedia_service.h b/pipeline/includes/wikipedia/wikipedia_service.h index 0e31345..2507d8a 100644 --- a/pipeline/includes/wikipedia/wikipedia_service.h +++ b/pipeline/includes/wikipedia/wikipedia_service.h @@ -10,18 +10,18 @@ /// @brief Provides cached Wikipedia summary lookups for city and country pairs. class WikipediaService { -public: - /// @brief Creates a new Wikipedia service with the provided web client. - explicit WikipediaService(std::shared_ptr client); + public: + /// @brief Creates a new Wikipedia service with the provided web client. + explicit WikipediaService(std::shared_ptr client); - /// @brief Returns the Wikipedia summary extract for city and country. - [[nodiscard]] std::string GetSummary(std::string_view city, - std::string_view country); + /// @brief Returns the Wikipedia summary extract for city and country. + [[nodiscard]] std::string GetSummary(std::string_view city, + std::string_view country); -private: - std::string FetchExtract(std::string_view query); - std::shared_ptr client_; - std::unordered_map cache_; + private: + std::string FetchExtract(std::string_view query); + std::shared_ptr client_; + std::unordered_map cache_; }; #endif // BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_ diff --git a/pipeline/src/data_generation/data_downloader.cpp b/pipeline/src/data_generation/data_downloader.cpp index 79cbddc..802d90c 100644 --- a/pipeline/src/data_generation/data_downloader.cpp +++ b/pipeline/src/data_generation/data_downloader.cpp @@ -1,46 +1,49 @@ #include "data_generation/data_downloader.h" -#include "web_client/web_client.h" + +#include + #include #include -#include #include #include +#include "web_client/web_client.h" + DataDownloader::DataDownloader(std::shared_ptr web_client) : web_client_(std::move(web_client)) {} DataDownloader::~DataDownloader() {} -bool DataDownloader::FileExists(const std::string &file_path) { - return std::filesystem::exists(file_path); +bool DataDownloader::FileExists(const std::string& file_path) { + return std::filesystem::exists(file_path); } -std::string -DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, - const std::string &commit) { - if (FileExists(cache_path)) { - spdlog::info("[DataDownloader] Cache hit: {}", cache_path); - return cache_path; - } +std::string DataDownloader::DownloadCountriesDatabase( + const std::string& cache_path, const std::string& commit) { + if (FileExists(cache_path)) { + spdlog::info("[DataDownloader] Cache hit: {}", cache_path); + return cache_path; + } - std::string short_commit = commit; - if (commit.length() > 7) { - short_commit = commit.substr(0, 7); - } + std::string short_commit = commit; + if (commit.length() > 7) { + short_commit = commit.substr(0, 7); + } - std::string url = "https://raw.githubusercontent.com/dr5hn/" - "countries-states-cities-database/" + - short_commit + "/json/countries+states+cities.json"; + std::string url = + "https://raw.githubusercontent.com/dr5hn/" + "countries-states-cities-database/" + + short_commit + "/json/countries+states+cities.json"; - spdlog::info("[DataDownloader] Downloading: {}", url); + spdlog::info("[DataDownloader] Downloading: {}", url); - web_client_->DownloadToFile(url, cache_path); + web_client_->DownloadToFile(url, cache_path); - std::ifstream file_check(cache_path, std::ios::binary | std::ios::ate); - std::streamsize size = file_check.tellg(); - file_check.close(); + std::ifstream file_check(cache_path, std::ios::binary | std::ios::ate); + std::streamsize size = file_check.tellg(); + file_check.close(); - spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)", - cache_path, (size / (1024.0 * 1024.0))); - return cache_path; + spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)", + cache_path, (size / (1024.0 * 1024.0))); + return cache_path; } diff --git a/pipeline/src/data_generation/llama/destructor.cpp b/pipeline/src/data_generation/llama/destructor.cpp index 1cdde40..957e071 100644 --- a/pipeline/src/data_generation/llama/destructor.cpp +++ b/pipeline/src/data_generation/llama/destructor.cpp @@ -1,17 +1,16 @@ +#include "data_generation/llama_generator.h" #include "llama.h" -#include "data_generation/llama_generator.h" - LlamaGenerator::~LlamaGenerator() { - if (context_ != nullptr) { - llama_free(context_); - context_ = nullptr; - } + if (context_ != nullptr) { + llama_free(context_); + context_ = nullptr; + } - if (model_ != nullptr) { - llama_model_free(model_); - model_ = nullptr; - } + if (model_ != nullptr) { + llama_model_free(model_); + model_ = nullptr; + } - llama_backend_free(); + llama_backend_free(); } diff --git a/pipeline/src/data_generation/llama/generate_brewery.cpp b/pipeline/src/data_generation/llama/generate_brewery.cpp index ff0b663..3453662 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cpp +++ b/pipeline/src/data_generation/llama/generate_brewery.cpp @@ -1,72 +1,74 @@ +#include + #include #include -#include - #include "data_generation/llama_generator.h" #include "data_generation/llama_generator_helpers.h" -BreweryResult -LlamaGenerator::GenerateBrewery(const std::string& city_name, - const std::string& country_name, - const std::string& region_context) { - const std::string safe_region_context = - PrepareRegionContextPublic(region_context); +BreweryResult LlamaGenerator::GenerateBrewery( + const std::string& city_name, const std::string& country_name, + const std::string& region_context) { + const std::string safe_region_context = + PrepareRegionContextPublic(region_context); - const std::string system_prompt = - "You are a copywriter for a craft beer travel guide. " - "Your writing is vivid, specific to place, and avoids generic beer " - "cliches. " - "You must output ONLY valid JSON. " - "The JSON schema must be exactly: {\"name\": \"string\", " - "\"description\": \"string\"}. " - "Do not include markdown formatting or backticks."; + const std::string system_prompt = + "You are the brewmaster and owner of a local craft brewery. " + "Write a name and a short, soulful description for your brewery that " + "reflects your pride in the local community and your craft. " + "The tone should be authentic and welcoming, like a note on a " + "chalkboard " + "menu. Output ONLY a single JSON object with keys \"name\" and " + "\"description\". " + "Do not include markdown formatting or backticks."; - std::string prompt = - "Write a brewery name and place-specific description for a craft " - "brewery in " + - city_name + - (country_name.empty() ? std::string("") - : std::string(", ") + country_name) + - (safe_region_context.empty() - ? std::string(".") - : std::string(". Regional context: ") + safe_region_context); + std::string prompt = + "Write a brewery name and place-specific description for a craft " + "brewery in " + + city_name + + (country_name.empty() ? std::string("") + : std::string(", ") + country_name) + + (safe_region_context.empty() + ? std::string(".") + : std::string(". Regional context: ") + safe_region_context); - const int max_attempts = 3; - std::string raw; - std::string last_error; - for (int attempt = 0; attempt < max_attempts; ++attempt) { - raw = Infer(system_prompt, prompt, 384); - spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, - raw); + const int max_attempts = 3; + std::string raw; + std::string last_error; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + raw = Infer(system_prompt, prompt, 384); + spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, + raw); - std::string name; - std::string description; - const std::string validation_error = - ValidateBreweryJsonPublic(raw, name, description); - if (validation_error.empty()) { - return {std::move(name), std::move(description)}; - } + std::string name; + std::string description; + const std::string validation_error = + ValidateBreweryJsonPublic(raw, name, description); + if (validation_error.empty()) { + return {std::move(name), std::move(description)}; + } - last_error = validation_error; - spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", - attempt + 1, validation_error); + last_error = validation_error; + spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", + attempt + 1, validation_error); - prompt = "Your previous response was invalid. Error: " + validation_error + - "\nReturn ONLY valid JSON with this exact schema: " - "{\"name\": \"string\", \"description\": \"string\"}." - "\nDo not include markdown, comments, or extra keys." - "\n\nLocation: " + - city_name + - (country_name.empty() ? std::string("") - : std::string(", ") + country_name) + - (safe_region_context.empty() - ? std::string("") - : std::string("\nRegional context: ") + safe_region_context); - } + prompt = + "Your previous response was invalid. Error: " + validation_error + + "\nReturn ONLY valid JSON with this exact schema: " + "{\"name\": \"string\", \"description\": \"string\"}." + "\nDo not include markdown, comments, or extra keys." + "\n\nLocation: " + + city_name + + (country_name.empty() ? std::string("") + : std::string(", ") + country_name) + + (safe_region_context.empty() + ? std::string("") + : std::string("\nRegional context: ") + safe_region_context); + } - spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " - "{}", - max_attempts, last_error.empty() ? raw : last_error); - throw std::runtime_error("LlamaGenerator: malformed brewery response"); + spdlog::error( + "LlamaGenerator: malformed brewery response after {} attempts: " + "{}", + max_attempts, last_error.empty() ? raw : last_error); + throw std::runtime_error("LlamaGenerator: malformed brewery response"); } diff --git a/pipeline/src/data_generation/llama/generate_user.cpp b/pipeline/src/data_generation/llama/generate_user.cpp index 4cf8671..22fb57a 100644 --- a/pipeline/src/data_generation/llama/generate_user.cpp +++ b/pipeline/src/data_generation/llama/generate_user.cpp @@ -1,56 +1,57 @@ +#include + #include #include #include -#include - #include "data_generation/llama_generator.h" #include "data_generation/llama_generator_helpers.h" UserResult LlamaGenerator::GenerateUser(const std::string& locale) { - const std::string system_prompt = - "You generate plausible social media profiles for craft beer " - "enthusiasts. " - "Respond with exactly two lines: " - "the first line is a username (lowercase, no spaces, 8-20 characters), " - "the second line is a one-sentence bio (20-40 words). " - "The profile should feel consistent with the locale. " - "No preamble, no labels."; + const std::string system_prompt = + "You generate plausible social media profiles for craft beer " + "enthusiasts. " + "Respond with exactly two lines: " + "the first line is a username (lowercase, no spaces, 8-20 characters), " + "the second line is a one-sentence bio (20-40 words). " + "The profile should feel consistent with the locale. " + "No preamble, no labels."; - std::string prompt = - "Generate a craft beer enthusiast profile. Locale: " + locale; + std::string prompt = + "Generate a craft beer enthusiast profile. Locale: " + locale; - const int max_attempts = 3; - std::string raw; - for (int attempt = 0; attempt < max_attempts; ++attempt) { - raw = Infer(system_prompt, prompt, 128); - spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}", - attempt + 1, raw); + const int max_attempts = 3; + std::string raw; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + raw = Infer(system_prompt, prompt, 128); + spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}", + attempt + 1, raw); - try { - auto [username, bio] = ParseTwoLineResponsePublic( - raw, "LlamaGenerator: malformed user response"); + try { + auto [username, bio] = ParseTwoLineResponsePublic( + raw, "LlamaGenerator: malformed user response"); - username.erase( - std::remove_if(username.begin(), username.end(), - [](unsigned char ch) { return std::isspace(ch); }), - username.end()); + username.erase( + std::remove_if(username.begin(), username.end(), + [](unsigned char ch) { return std::isspace(ch); }), + username.end()); - if (username.empty() || bio.empty()) { - throw std::runtime_error("LlamaGenerator: malformed user response"); + if (username.empty() || bio.empty()) { + throw std::runtime_error("LlamaGenerator: malformed user response"); + } + + if (bio.size() > 200) bio = bio.substr(0, 200); + + return {username, bio}; + } catch (const std::exception& e) { + spdlog::warn( + "LlamaGenerator: malformed user response (attempt {}): {}", + attempt + 1, e.what()); } + } - if (bio.size() > 200) - bio = bio.substr(0, 200); - - return {username, bio}; - } catch (const std::exception &e) { - spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}", - attempt + 1, e.what()); - } - } - - spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}", - max_attempts, raw); - throw std::runtime_error("LlamaGenerator: malformed user response"); + spdlog::error( + "LlamaGenerator: malformed user response after {} attempts: {}", + max_attempts, raw); + throw std::runtime_error("LlamaGenerator: malformed user response"); } diff --git a/pipeline/src/data_generation/llama/helpers.cpp b/pipeline/src/data_generation/llama/helpers.cpp index c1343fc..18a7e8f 100644 --- a/pipeline/src/data_generation/llama/helpers.cpp +++ b/pipeline/src/data_generation/llama/helpers.cpp @@ -1,367 +1,365 @@ #include #include +#include #include #include #include #include #include -#include "llama.h" -#include - #include "data_generation/llama_generator.h" +#include "llama.h" namespace { std::string Trim(std::string value) { - auto not_space = [](unsigned char ch) { return !std::isspace(ch); }; + auto not_space = [](unsigned char ch) { return !std::isspace(ch); }; - value.erase(value.begin(), - std::find_if(value.begin(), value.end(), not_space)); - value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(), - value.end()); + value.erase(value.begin(), + std::find_if(value.begin(), value.end(), not_space)); + value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(), + value.end()); - return value; + return value; } std::string CondenseWhitespace(std::string text) { - std::string out; - out.reserve(text.size()); + std::string out; + out.reserve(text.size()); - bool in_whitespace = false; - for (unsigned char ch : text) { - if (std::isspace(ch)) { - if (!in_whitespace) { - out.push_back(' '); - in_whitespace = true; + bool in_whitespace = false; + for (unsigned char ch : text) { + if (std::isspace(ch)) { + if (!in_whitespace) { + out.push_back(' '); + in_whitespace = true; + } + continue; } - continue; - } - in_whitespace = false; - out.push_back(static_cast(ch)); - } + in_whitespace = false; + out.push_back(static_cast(ch)); + } - return Trim(std::move(out)); + return Trim(std::move(out)); } std::string PrepareRegionContext(std::string_view region_context, std::size_t max_chars) { - std::string normalized = CondenseWhitespace(std::string(region_context)); - if (normalized.size() <= max_chars) { - return normalized; - } + std::string normalized = CondenseWhitespace(std::string(region_context)); + if (normalized.size() <= max_chars) { + return normalized; + } - normalized.resize(max_chars); - const std::size_t last_space = normalized.find_last_of(' '); - if (last_space != std::string::npos && last_space > max_chars / 2) { - normalized.resize(last_space); - } + normalized.resize(max_chars); + const std::size_t last_space = normalized.find_last_of(' '); + if (last_space != std::string::npos && last_space > max_chars / 2) { + normalized.resize(last_space); + } - normalized += "..."; - return normalized; + normalized += "..."; + return normalized; } std::string StripCommonPrefix(std::string line) { - line = Trim(std::move(line)); + line = Trim(std::move(line)); - if (!line.empty() && (line[0] == '-' || line[0] == '*')) { - line = Trim(line.substr(1)); - } else { - std::size_t i = 0; - while (i < line.size() && - std::isdigit(static_cast(line[i]))) { - ++i; - } - if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) { - line = Trim(line.substr(i + 1)); - } - } - - auto strip_label = [&line](const std::string& label) { - if (line.size() >= label.size()) { - bool matches = true; - for (std::size_t i = 0; i < label.size(); ++i) { - if (std::tolower(static_cast(line[i])) != - std::tolower(static_cast(label[i]))) { - matches = false; - break; - } + if (!line.empty() && (line[0] == '-' || line[0] == '*')) { + line = Trim(line.substr(1)); + } else { + std::size_t i = 0; + while (i < line.size() && + std::isdigit(static_cast(line[i]))) { + ++i; } - if (matches) { - line = Trim(line.substr(label.size())); + if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) { + line = Trim(line.substr(i + 1)); } - } - }; + } - strip_label("name:"); - strip_label("brewery name:"); - strip_label("description:"); - strip_label("username:"); - strip_label("bio:"); + auto strip_label = [&line](const std::string& label) { + if (line.size() >= label.size()) { + bool matches = true; + for (std::size_t i = 0; i < label.size(); ++i) { + if (std::tolower(static_cast(line[i])) != + std::tolower(static_cast(label[i]))) { + matches = false; + break; + } + } + if (matches) { + line = Trim(line.substr(label.size())); + } + } + }; - return Trim(std::move(line)); + strip_label("name:"); + strip_label("brewery name:"); + strip_label("description:"); + strip_label("username:"); + strip_label("bio:"); + + return Trim(std::move(line)); } -std::pair -ParseTwoLineResponse(const std::string& raw, const std::string& error_message) { - std::string normalized = raw; - std::replace(normalized.begin(), normalized.end(), '\r', '\n'); +std::pair ParseTwoLineResponse( + const std::string& raw, const std::string& error_message) { + std::string normalized = raw; + std::replace(normalized.begin(), normalized.end(), '\r', '\n'); - std::vector lines; - std::stringstream stream(normalized); - std::string line; - while (std::getline(stream, line)) { - line = StripCommonPrefix(std::move(line)); - if (!line.empty()) - lines.push_back(std::move(line)); - } + std::vector lines; + std::stringstream stream(normalized); + std::string line; + while (std::getline(stream, line)) { + line = StripCommonPrefix(std::move(line)); + if (!line.empty()) lines.push_back(std::move(line)); + } - std::vector filtered; - for (auto &l : lines) { - std::string low = l; - std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { - return static_cast(std::tolower(c)); - }); - if (!l.empty() && l.front() == '<' && low.back() == '>') - continue; - if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) - continue; - filtered.push_back(std::move(l)); - } + std::vector filtered; + for (auto& l : lines) { + std::string low = l; + std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (!l.empty() && l.front() == '<' && low.back() == '>') continue; + if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue; + filtered.push_back(std::move(l)); + } - if (filtered.size() < 2) - throw std::runtime_error(error_message); + if (filtered.size() < 2) throw std::runtime_error(error_message); - std::string first = Trim(filtered.front()); - std::string second; - for (size_t i = 1; i < filtered.size(); ++i) { - if (!second.empty()) - second += ' '; - second += filtered[i]; - } - second = Trim(std::move(second)); + std::string first = Trim(filtered.front()); + std::string second; + for (size_t i = 1; i < filtered.size(); ++i) { + if (!second.empty()) second += ' '; + second += filtered[i]; + } + second = Trim(std::move(second)); - if (first.empty() || second.empty()) - throw std::runtime_error(error_message); - return {first, second}; + if (first.empty() || second.empty()) throw std::runtime_error(error_message); + return {first, second}; } -std::string ToChatPrompt(const llama_model *model, +std::string ToChatPrompt(const llama_model* model, const std::string& user_prompt) { - const char *tmpl = llama_model_chat_template(model, nullptr); - if (tmpl == nullptr) { - return user_prompt; - } + const char* tmpl = llama_model_chat_template(model, nullptr); + if (tmpl == nullptr) { + return user_prompt; + } - const llama_chat_message message{"user", user_prompt.c_str()}; + const llama_chat_message message{"user", user_prompt.c_str()}; - std::vector buffer(std::max(1024, user_prompt.size() * 4)); - int32_t required = - llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), - static_cast(buffer.size())); + std::vector buffer( + std::max(1024, user_prompt.size() * 4)); + int32_t required = + llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), + static_cast(buffer.size())); - if (required < 0) { - throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - - if (required >= static_cast(buffer.size())) { - buffer.resize(static_cast(required) + 1); - required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), - static_cast(buffer.size())); - if (required < 0) { + if (required < 0) { throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - } + } - return std::string(buffer.data(), static_cast(required)); + if (required >= static_cast(buffer.size())) { + buffer.resize(static_cast(required) + 1); + required = + llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), + static_cast(buffer.size())); + if (required < 0) { + throw std::runtime_error( + "LlamaGenerator: failed to apply chat template"); + } + } + + return std::string(buffer.data(), static_cast(required)); } -std::string ToChatPrompt(const llama_model *model, +std::string ToChatPrompt(const llama_model* model, const std::string& system_prompt, const std::string& user_prompt) { - const char *tmpl = llama_model_chat_template(model, nullptr); - if (tmpl == nullptr) { - return system_prompt + "\n\n" + user_prompt; - } + const char* tmpl = llama_model_chat_template(model, nullptr); + if (tmpl == nullptr) { + return system_prompt + "\n\n" + user_prompt; + } - const llama_chat_message messages[2] = {{"system", system_prompt.c_str()}, - {"user", user_prompt.c_str()}}; + const llama_chat_message messages[2] = {{"system", system_prompt.c_str()}, + {"user", user_prompt.c_str()}}; - std::vector buffer(std::max( - 1024, (system_prompt.size() + user_prompt.size()) * 4)); - int32_t required = - llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), - static_cast(buffer.size())); + std::vector buffer(std::max( + 1024, (system_prompt.size() + user_prompt.size()) * 4)); + int32_t required = + llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), + static_cast(buffer.size())); - if (required < 0) { - throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - - if (required >= static_cast(buffer.size())) { - buffer.resize(static_cast(required) + 1); - required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), - static_cast(buffer.size())); - if (required < 0) { + if (required < 0) { throw std::runtime_error("LlamaGenerator: failed to apply chat template"); - } - } + } - return std::string(buffer.data(), static_cast(required)); + if (required >= static_cast(buffer.size())) { + buffer.resize(static_cast(required) + 1); + required = + llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), + static_cast(buffer.size())); + if (required < 0) { + throw std::runtime_error( + "LlamaGenerator: failed to apply chat template"); + } + } + + return std::string(buffer.data(), static_cast(required)); } -void AppendTokenPiece(const llama_vocab *vocab, llama_token token, +void AppendTokenPiece(const llama_vocab* vocab, llama_token token, std::string& output) { - std::array buffer{}; - int32_t bytes = - llama_token_to_piece(vocab, token, buffer.data(), - static_cast(buffer.size()), 0, true); + std::array buffer{}; + int32_t bytes = + llama_token_to_piece(vocab, token, buffer.data(), + static_cast(buffer.size()), 0, true); - if (bytes < 0) { - std::vector dynamic_buffer(static_cast(-bytes)); - bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), - static_cast(dynamic_buffer.size()), 0, - true); - if (bytes < 0) { - throw std::runtime_error( - "LlamaGenerator: failed to decode sampled token piece"); - } + if (bytes < 0) { + std::vector dynamic_buffer(static_cast(-bytes)); + bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), + static_cast(dynamic_buffer.size()), + 0, true); + if (bytes < 0) { + throw std::runtime_error( + "LlamaGenerator: failed to decode sampled token piece"); + } - output.append(dynamic_buffer.data(), static_cast(bytes)); - return; - } + output.append(dynamic_buffer.data(), static_cast(bytes)); + return; + } - output.append(buffer.data(), static_cast(bytes)); + output.append(buffer.data(), static_cast(bytes)); } bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) { - std::size_t start = std::string::npos; - int depth = 0; - bool in_string = false; - bool escaped = false; + std::size_t start = std::string::npos; + int depth = 0; + bool in_string = false; + bool escaped = false; - for (std::size_t i = 0; i < text.size(); ++i) { - const char ch = text[i]; + for (std::size_t i = 0; i < text.size(); ++i) { + const char ch = text[i]; - if (in_string) { - if (escaped) { - escaped = false; - } else if (ch == '\\') { - escaped = true; - } else if (ch == '"') { - in_string = false; + if (in_string) { + if (escaped) { + escaped = false; + } else if (ch == '\\') { + escaped = true; + } else if (ch == '"') { + in_string = false; + } + continue; } - continue; - } - if (ch == '"') { - in_string = true; - continue; - } - - if (ch == '{') { - if (depth == 0) { - start = i; + if (ch == '"') { + in_string = true; + continue; } - ++depth; - continue; - } - if (ch == '}') { - if (depth == 0) { - continue; + if (ch == '{') { + if (depth == 0) { + start = i; + } + ++depth; + continue; } - --depth; - if (depth == 0 && start != std::string::npos) { - json_out = text.substr(start, i - start + 1); - return true; - } - } - } - return false; + if (ch == '}') { + if (depth == 0) { + continue; + } + --depth; + if (depth == 0 && start != std::string::npos) { + json_out = text.substr(start, i - start + 1); + return true; + } + } + } + + return false; } std::string ValidateBreweryJson(const std::string& raw, std::string& name_out, std::string& description_out) { - auto validate_object = [&](const boost::json::value& jv, - std::string& error_out) -> bool { - if (!jv.is_object()) { - error_out = "JSON root must be an object"; - return false; - } + auto validate_object = [&](const boost::json::value& jv, + std::string& error_out) -> bool { + if (!jv.is_object()) { + error_out = "JSON root must be an object"; + return false; + } - const auto& obj = jv.get_object(); - if (!obj.contains("name") || !obj.at("name").is_string()) { - error_out = "JSON field 'name' is missing or not a string"; - return false; - } + const auto& obj = jv.get_object(); + if (!obj.contains("name") || !obj.at("name").is_string()) { + error_out = "JSON field 'name' is missing or not a string"; + return false; + } - if (!obj.contains("description") || !obj.at("description").is_string()) { - error_out = "JSON field 'description' is missing or not a string"; - return false; - } + if (!obj.contains("description") || !obj.at("description").is_string()) { + error_out = "JSON field 'description' is missing or not a string"; + return false; + } - name_out = Trim(std::string(obj.at("name").as_string().c_str())); - description_out = - Trim(std::string(obj.at("description").as_string().c_str())); + name_out = Trim(std::string(obj.at("name").as_string().c_str())); + description_out = + Trim(std::string(obj.at("description").as_string().c_str())); - if (name_out.empty()) { - error_out = "JSON field 'name' must not be empty"; - return false; - } + if (name_out.empty()) { + error_out = "JSON field 'name' must not be empty"; + return false; + } - if (description_out.empty()) { - error_out = "JSON field 'description' must not be empty"; - return false; - } + if (description_out.empty()) { + error_out = "JSON field 'description' must not be empty"; + return false; + } - std::string name_lower = name_out; - std::string description_lower = description_out; - std::transform( - name_lower.begin(), name_lower.end(), name_lower.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - std::transform(description_lower.begin(), description_lower.end(), - description_lower.begin(), [](unsigned char c) { - return static_cast(std::tolower(c)); - }); + std::string name_lower = name_out; + std::string description_lower = description_out; + std::transform( + name_lower.begin(), name_lower.end(), name_lower.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::transform(description_lower.begin(), description_lower.end(), + description_lower.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); - if (name_lower == "string" || description_lower == "string") { - error_out = "JSON appears to be a schema placeholder, not content"; - return false; - } + if (name_lower == "string" || description_lower == "string") { + error_out = "JSON appears to be a schema placeholder, not content"; + return false; + } - error_out.clear(); - return true; - }; + error_out.clear(); + return true; + }; - boost::system::error_code ec; - boost::json::value jv = boost::json::parse(raw, ec); - std::string validation_error; - if (ec) { - std::string extracted; - if (!ExtractFirstJsonObject(raw, extracted)) { - return "JSON parse error: " + ec.message(); - } + boost::system::error_code ec; + boost::json::value jv = boost::json::parse(raw, ec); + std::string validation_error; + if (ec) { + std::string extracted; + if (!ExtractFirstJsonObject(raw, extracted)) { + return "JSON parse error: " + ec.message(); + } - ec.clear(); - jv = boost::json::parse(extracted, ec); - if (ec) { - return "JSON parse error: " + ec.message(); - } + ec.clear(); + jv = boost::json::parse(extracted, ec); + if (ec) { + return "JSON parse error: " + ec.message(); + } - if (!validate_object(jv, validation_error)) { + if (!validate_object(jv, validation_error)) { + return validation_error; + } + + return {}; + } + + if (!validate_object(jv, validation_error)) { return validation_error; - } + } - return {}; - } - - if (!validate_object(jv, validation_error)) { - return validation_error; - } - - return {}; + return {}; } } // namespace @@ -369,33 +367,32 @@ std::string ValidateBreweryJson(const std::string& raw, std::string& name_out, // Forward declarations for helper functions exposed to other translation units std::string PrepareRegionContextPublic(std::string_view region_context, std::size_t max_chars) { - return PrepareRegionContext(region_context, max_chars); + return PrepareRegionContext(region_context, max_chars); } -std::pair -ParseTwoLineResponsePublic(const std::string& raw, - const std::string& error_message) { - return ParseTwoLineResponse(raw, error_message); +std::pair ParseTwoLineResponsePublic( + const std::string& raw, const std::string& error_message) { + return ParseTwoLineResponse(raw, error_message); } -std::string ToChatPromptPublic(const llama_model *model, +std::string ToChatPromptPublic(const llama_model* model, const std::string& user_prompt) { - return ToChatPrompt(model, user_prompt); + return ToChatPrompt(model, user_prompt); } -std::string ToChatPromptPublic(const llama_model *model, +std::string ToChatPromptPublic(const llama_model* model, const std::string& system_prompt, const std::string& user_prompt) { - return ToChatPrompt(model, system_prompt, user_prompt); + return ToChatPrompt(model, system_prompt, user_prompt); } -void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, +void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, std::string& output) { - AppendTokenPiece(vocab, token, output); + AppendTokenPiece(vocab, token, output); } std::string ValidateBreweryJsonPublic(const std::string& raw, std::string& name_out, std::string& description_out) { - return ValidateBreweryJson(raw, name_out, description_out); + return ValidateBreweryJson(raw, name_out, description_out); } diff --git a/pipeline/src/data_generation/llama/infer.cpp b/pipeline/src/data_generation/llama/infer.cpp index b3a4e13..c938f87 100644 --- a/pipeline/src/data_generation/llama/infer.cpp +++ b/pipeline/src/data_generation/llama/infer.cpp @@ -1,195 +1,199 @@ +#include + #include #include #include #include #include -#include "llama.h" -#include - #include "data_generation/llama_generator.h" #include "data_generation/llama_generator_helpers.h" +#include "llama.h" std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { - if (model_ == nullptr || context_ == nullptr) - throw std::runtime_error("LlamaGenerator: model not loaded"); + if (model_ == nullptr || context_ == nullptr) + throw std::runtime_error("LlamaGenerator: model not loaded"); - const llama_vocab *vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) - throw std::runtime_error("LlamaGenerator: vocab unavailable"); + const llama_vocab* vocab = llama_model_get_vocab(model_); + if (vocab == nullptr) + throw std::runtime_error("LlamaGenerator: vocab unavailable"); - llama_memory_clear(llama_get_memory(context_), true); + llama_memory_clear(llama_get_memory(context_), true); - const std::string formatted_prompt = ToChatPromptPublic(model_, prompt); + const std::string formatted_prompt = ToChatPromptPublic(model_, prompt); - std::vector prompt_tokens(formatted_prompt.size() + 8); - int32_t token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); + std::vector prompt_tokens(formatted_prompt.size() + 8); + int32_t token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); - if (token_count < 0) { - prompt_tokens.resize(static_cast(-token_count)); - token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); - } + if (token_count < 0) { + prompt_tokens.resize(static_cast(-token_count)); + token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); + } - if (token_count < 0) - throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + if (token_count < 0) + throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); - const int32_t n_ctx = static_cast(llama_n_ctx(context_)); - const int32_t n_batch = static_cast(llama_n_batch(context_)); - if (n_ctx <= 1 || n_batch <= 0) { - throw std::runtime_error("LlamaGenerator: invalid context or batch size"); - } + const int32_t n_ctx = static_cast(llama_n_ctx(context_)); + const int32_t n_batch = static_cast(llama_n_batch(context_)); + if (n_ctx <= 1 || n_batch <= 0) { + throw std::runtime_error("LlamaGenerator: invalid context or batch size"); + } - const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); - int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); - prompt_budget = std::max(1, prompt_budget); + const int32_t effective_max_tokens = + std::max(1, std::min(max_tokens, n_ctx - 1)); + int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); + prompt_budget = std::max(1, prompt_budget); - prompt_tokens.resize(static_cast(token_count)); - if (token_count > prompt_budget) { - spdlog::warn( - "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " - "to fit n_batch/n_ctx limits", - token_count, prompt_budget); - prompt_tokens.resize(static_cast(prompt_budget)); - token_count = prompt_budget; - } + prompt_tokens.resize(static_cast(token_count)); + if (token_count > prompt_budget) { + spdlog::warn( + "LlamaGenerator: prompt too long ({} tokens), truncating to {} " + "tokens " + "to fit n_batch/n_ctx limits", + token_count, prompt_budget); + prompt_tokens.resize(static_cast(prompt_budget)); + token_count = prompt_budget; + } - const llama_batch prompt_batch = llama_batch_get_one( - prompt_tokens.data(), static_cast(prompt_tokens.size())); - if (llama_decode(context_, prompt_batch) != 0) - throw std::runtime_error("LlamaGenerator: prompt decode failed"); + const llama_batch prompt_batch = llama_batch_get_one( + prompt_tokens.data(), static_cast(prompt_tokens.size())); + if (llama_decode(context_, prompt_batch) != 0) + throw std::runtime_error("LlamaGenerator: prompt decode failed"); - llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - using SamplerPtr = - std::unique_ptr; - SamplerPtr sampler(llama_sampler_chain_init(sampler_params), - &llama_sampler_free); - if (!sampler) - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + llama_sampler_chain_params sampler_params = + llama_sampler_chain_default_params(); + using SamplerPtr = + std::unique_ptr; + SamplerPtr sampler(llama_sampler_chain_init(sampler_params), + &llama_sampler_free); + if (!sampler) + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_temp(sampling_temperature_)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_top_p(sampling_top_p_, 1)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_dist(sampling_seed_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_temp(sampling_temperature_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(sampling_seed_)); - std::vector generated_tokens; - generated_tokens.reserve(static_cast(max_tokens)); + std::vector generated_tokens; + generated_tokens.reserve(static_cast(max_tokens)); - for (int i = 0; i < effective_max_tokens; ++i) { - const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) - break; - generated_tokens.push_back(next); - llama_token token = next; - const llama_batch one_token_batch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, one_token_batch) != 0) - throw std::runtime_error( - "LlamaGenerator: decode failed during generation"); - } + for (int i = 0; i < effective_max_tokens; ++i) { + const llama_token next = + llama_sampler_sample(sampler.get(), context_, -1); + if (llama_vocab_is_eog(vocab, next)) break; + generated_tokens.push_back(next); + llama_token token = next; + const llama_batch one_token_batch = llama_batch_get_one(&token, 1); + if (llama_decode(context_, one_token_batch) != 0) + throw std::runtime_error( + "LlamaGenerator: decode failed during generation"); + } - std::string output; - for (const llama_token token : generated_tokens) - AppendTokenPiecePublic(vocab, token, output); - return output; + std::string output; + for (const llama_token token : generated_tokens) + AppendTokenPiecePublic(vocab, token, output); + return output; } std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, int max_tokens) { - if (model_ == nullptr || context_ == nullptr) - throw std::runtime_error("LlamaGenerator: model not loaded"); + if (model_ == nullptr || context_ == nullptr) + throw std::runtime_error("LlamaGenerator: model not loaded"); - const llama_vocab *vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) - throw std::runtime_error("LlamaGenerator: vocab unavailable"); + const llama_vocab* vocab = llama_model_get_vocab(model_); + if (vocab == nullptr) + throw std::runtime_error("LlamaGenerator: vocab unavailable"); - llama_memory_clear(llama_get_memory(context_), true); + llama_memory_clear(llama_get_memory(context_), true); - const std::string formatted_prompt = - ToChatPromptPublic(model_, system_prompt, prompt); + const std::string formatted_prompt = + ToChatPromptPublic(model_, system_prompt, prompt); - std::vector prompt_tokens(formatted_prompt.size() + 8); - int32_t token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); + std::vector prompt_tokens(formatted_prompt.size() + 8); + int32_t token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); - if (token_count < 0) { - prompt_tokens.resize(static_cast(-token_count)); - token_count = llama_tokenize( - vocab, formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), prompt_tokens.data(), - static_cast(prompt_tokens.size()), true, true); - } + if (token_count < 0) { + prompt_tokens.resize(static_cast(-token_count)); + token_count = llama_tokenize( + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); + } - if (token_count < 0) - throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + if (token_count < 0) + throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); - const int32_t n_ctx = static_cast(llama_n_ctx(context_)); - const int32_t n_batch = static_cast(llama_n_batch(context_)); - if (n_ctx <= 1 || n_batch <= 0) { - throw std::runtime_error("LlamaGenerator: invalid context or batch size"); - } + const int32_t n_ctx = static_cast(llama_n_ctx(context_)); + const int32_t n_batch = static_cast(llama_n_batch(context_)); + if (n_ctx <= 1 || n_batch <= 0) { + throw std::runtime_error("LlamaGenerator: invalid context or batch size"); + } - const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); - int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); - prompt_budget = std::max(1, prompt_budget); + const int32_t effective_max_tokens = + std::max(1, std::min(max_tokens, n_ctx - 1)); + int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); + prompt_budget = std::max(1, prompt_budget); - prompt_tokens.resize(static_cast(token_count)); - if (token_count > prompt_budget) { - spdlog::warn( - "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " - "to fit n_batch/n_ctx limits", - token_count, prompt_budget); - prompt_tokens.resize(static_cast(prompt_budget)); - token_count = prompt_budget; - } + prompt_tokens.resize(static_cast(token_count)); + if (token_count > prompt_budget) { + spdlog::warn( + "LlamaGenerator: prompt too long ({} tokens), truncating to {} " + "tokens " + "to fit n_batch/n_ctx limits", + token_count, prompt_budget); + prompt_tokens.resize(static_cast(prompt_budget)); + token_count = prompt_budget; + } - const llama_batch prompt_batch = llama_batch_get_one( - prompt_tokens.data(), static_cast(prompt_tokens.size())); - if (llama_decode(context_, prompt_batch) != 0) - throw std::runtime_error("LlamaGenerator: prompt decode failed"); + const llama_batch prompt_batch = llama_batch_get_one( + prompt_tokens.data(), static_cast(prompt_tokens.size())); + if (llama_decode(context_, prompt_batch) != 0) + throw std::runtime_error("LlamaGenerator: prompt decode failed"); - llama_sampler_chain_params sampler_params = - llama_sampler_chain_default_params(); - using SamplerPtr = - std::unique_ptr; - SamplerPtr sampler(llama_sampler_chain_init(sampler_params), - &llama_sampler_free); - if (!sampler) - throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + llama_sampler_chain_params sampler_params = + llama_sampler_chain_default_params(); + using SamplerPtr = + std::unique_ptr; + SamplerPtr sampler(llama_sampler_chain_init(sampler_params), + &llama_sampler_free); + if (!sampler) + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_temp(sampling_temperature_)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_top_p(sampling_top_p_, 1)); - llama_sampler_chain_add(sampler.get(), - llama_sampler_init_dist(sampling_seed_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_temp(sampling_temperature_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(sampling_seed_)); - std::vector generated_tokens; - generated_tokens.reserve(static_cast(max_tokens)); + std::vector generated_tokens; + generated_tokens.reserve(static_cast(max_tokens)); - for (int i = 0; i < effective_max_tokens; ++i) { - const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) - break; - generated_tokens.push_back(next); - llama_token token = next; - const llama_batch one_token_batch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, one_token_batch) != 0) - throw std::runtime_error( - "LlamaGenerator: decode failed during generation"); - } + for (int i = 0; i < effective_max_tokens; ++i) { + const llama_token next = + llama_sampler_sample(sampler.get(), context_, -1); + if (llama_vocab_is_eog(vocab, next)) break; + generated_tokens.push_back(next); + llama_token token = next; + const llama_batch one_token_batch = llama_batch_get_one(&token, 1); + if (llama_decode(context_, one_token_batch) != 0) + throw std::runtime_error( + "LlamaGenerator: decode failed during generation"); + } - std::string output; - for (const llama_token token : generated_tokens) - AppendTokenPiecePublic(vocab, token, output); - return output; + std::string output; + for (const llama_token token : generated_tokens) + AppendTokenPiecePublic(vocab, token, output); + return output; } diff --git a/pipeline/src/data_generation/llama/load.cpp b/pipeline/src/data_generation/llama/load.cpp index c38808b..1853b5f 100644 --- a/pipeline/src/data_generation/llama/load.cpp +++ b/pipeline/src/data_generation/llama/load.cpp @@ -1,42 +1,42 @@ +#include + #include #include -#include "llama.h" -#include - #include "data_generation/llama_generator.h" +#include "llama.h" void LlamaGenerator::Load(const std::string& model_path) { - if (model_path.empty()) - throw std::runtime_error("LlamaGenerator: model path must not be empty"); + if (model_path.empty()) + throw std::runtime_error("LlamaGenerator: model path must not be empty"); - if (context_ != nullptr) { - llama_free(context_); - context_ = nullptr; - } - if (model_ != nullptr) { - llama_model_free(model_); - model_ = nullptr; - } + if (context_ != nullptr) { + llama_free(context_); + context_ = nullptr; + } + if (model_ != nullptr) { + llama_model_free(model_); + model_ = nullptr; + } - llama_backend_init(); + llama_backend_init(); - llama_model_params model_params = llama_model_default_params(); - model_ = llama_model_load_from_file(model_path.c_str(), model_params); - if (model_ == nullptr) { - throw std::runtime_error( - "LlamaGenerator: failed to load model from path: " + model_path); - } + llama_model_params model_params = llama_model_default_params(); + model_ = llama_model_load_from_file(model_path.c_str(), model_params); + if (model_ == nullptr) { + throw std::runtime_error( + "LlamaGenerator: failed to load model from path: " + model_path); + } - llama_context_params context_params = llama_context_default_params(); - context_params.n_ctx = 2048; + llama_context_params context_params = llama_context_default_params(); + context_params.n_ctx = 2048; - context_ = llama_init_from_model(model_, context_params); - if (context_ == nullptr) { - llama_model_free(model_); - model_ = nullptr; - throw std::runtime_error("LlamaGenerator: failed to create context"); - } + context_ = llama_init_from_model(model_, context_params); + if (context_ == nullptr) { + llama_model_free(model_); + model_ = nullptr; + throw std::runtime_error("LlamaGenerator: failed to create context"); + } - spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); + spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); } diff --git a/pipeline/src/data_generation/llama/set_sampling_options.cpp b/pipeline/src/data_generation/llama/set_sampling_options.cpp index 3898fca..8953eda 100644 --- a/pipeline/src/data_generation/llama/set_sampling_options.cpp +++ b/pipeline/src/data_generation/llama/set_sampling_options.cpp @@ -1,26 +1,25 @@ #include -#include "llama.h" - #include "data_generation/llama_generator.h" +#include "llama.h" void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, int seed) { - if (temperature < 0.0f) { - throw std::runtime_error( - "LlamaGenerator: sampling temperature must be >= 0"); - } - if (!(top_p > 0.0f && top_p <= 1.0f)) { - throw std::runtime_error( - "LlamaGenerator: sampling top-p must be in (0, 1]"); - } - if (seed < -1) { - throw std::runtime_error( - "LlamaGenerator: seed must be >= 0, or -1 for random"); - } + if (temperature < 0.0f) { + throw std::runtime_error( + "LlamaGenerator: sampling temperature must be >= 0"); + } + if (!(top_p > 0.0f && top_p <= 1.0f)) { + throw std::runtime_error( + "LlamaGenerator: sampling top-p must be in (0, 1]"); + } + if (seed < -1) { + throw std::runtime_error( + "LlamaGenerator: seed must be >= 0, or -1 for random"); + } - sampling_temperature_ = temperature; - sampling_top_p_ = top_p; - sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED) - : static_cast(seed); + sampling_temperature_ = temperature; + sampling_top_p_ = top_p; + sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED) + : static_cast(seed); } diff --git a/pipeline/src/data_generation/mock_generator.cpp b/pipeline/src/data_generation/mock/data.cpp similarity index 69% rename from pipeline/src/data_generation/mock_generator.cpp rename to pipeline/src/data_generation/mock/data.cpp index 4623051..98637e4 100644 --- a/pipeline/src/data_generation/mock_generator.cpp +++ b/pipeline/src/data_generation/mock/data.cpp @@ -1,7 +1,7 @@ -#include "data_generation/mock_generator.h" +#include +#include -#include -#include +#include "data_generation/mock_generator.h" const std::vector MockGenerator::kBreweryAdjectives = { "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", @@ -63,42 +63,3 @@ const std::vector MockGenerator::kBios = { "Craft beer fan mapping tasting notes and favorite brew routes.", "Always ready to trade recommendations for underrated local breweries.", "Keeping a running list of must-try collab releases and tap takeovers."}; - -void MockGenerator::Load(const std::string & /*modelPath*/) { - spdlog::info("[MockGenerator] No model needed"); -} - -std::size_t MockGenerator::DeterministicHash(const std::string &a, - const std::string &b) { - std::size_t seed = std::hash{}(a); - const std::size_t mixed = std::hash{}(b); - seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); - seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13)); - return seed; -} - -BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name, - const std::string &country_name, - const std::string ®ion_context) { - const std::string location_key = - country_name.empty() ? city_name : city_name + "," + country_name; - const std::size_t hash = region_context.empty() - ? std::hash{}(location_key) - : DeterministicHash(location_key, region_context); - - BreweryResult result; - result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " + - kBreweryNouns[(hash / 7) % kBreweryNouns.size()]; - result.description = - kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()]; - return result; -} - -UserResult MockGenerator::GenerateUser(const std::string &locale) { - const std::size_t hash = std::hash{}(locale); - - UserResult result; - result.username = kUsernames[hash % kUsernames.size()]; - result.bio = kBios[(hash / 11) % kBios.size()]; - return result; -} diff --git a/pipeline/src/data_generation/mock/deterministic_hash.cpp b/pipeline/src/data_generation/mock/deterministic_hash.cpp new file mode 100644 index 0000000..e59c359 --- /dev/null +++ b/pipeline/src/data_generation/mock/deterministic_hash.cpp @@ -0,0 +1,12 @@ +#include + +#include "data_generation/mock_generator.h" + +std::size_t MockGenerator::DeterministicHash(const std::string& a, + const std::string& b) { + std::size_t seed = std::hash{}(a); + const std::size_t mixed = std::hash{}(b); + seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); + seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13)); + return seed; +} diff --git a/pipeline/src/data_generation/mock/generate_brewery.cpp b/pipeline/src/data_generation/mock/generate_brewery.cpp new file mode 100644 index 0000000..0f7b611 --- /dev/null +++ b/pipeline/src/data_generation/mock/generate_brewery.cpp @@ -0,0 +1,21 @@ +#include +#include + +#include "data_generation/mock_generator.h" + +BreweryResult MockGenerator::GenerateBrewery( + const std::string& city_name, const std::string& country_name, + const std::string& region_context) { + const std::string location_key = + country_name.empty() ? city_name : city_name + "," + country_name; + const std::size_t hash = + region_context.empty() ? std::hash{}(location_key) + : DeterministicHash(location_key, region_context); + + BreweryResult result; + result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " + + kBreweryNouns[(hash / 7) % kBreweryNouns.size()]; + result.description = + kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()]; + return result; +} diff --git a/pipeline/src/data_generation/mock/generate_user.cpp b/pipeline/src/data_generation/mock/generate_user.cpp new file mode 100644 index 0000000..1f46baa --- /dev/null +++ b/pipeline/src/data_generation/mock/generate_user.cpp @@ -0,0 +1,13 @@ +#include +#include + +#include "data_generation/mock_generator.h" + +UserResult MockGenerator::GenerateUser(const std::string& locale) { + const std::size_t hash = std::hash{}(locale); + + UserResult result; + result.username = kUsernames[hash % kUsernames.size()]; + result.bio = kBios[(hash / 11) % kBios.size()]; + return result; +} diff --git a/pipeline/src/data_generation/mock/load.cpp b/pipeline/src/data_generation/mock/load.cpp new file mode 100644 index 0000000..6d6d99b --- /dev/null +++ b/pipeline/src/data_generation/mock/load.cpp @@ -0,0 +1,9 @@ +#include + +#include + +#include "data_generation/mock_generator.h" + +void MockGenerator::Load(const std::string& /*modelPath*/) { + spdlog::info("[MockGenerator] No model needed"); +} diff --git a/pipeline/src/database/database.cpp b/pipeline/src/database/database.cpp index 98c3867..6242a22 100644 --- a/pipeline/src/database/database.cpp +++ b/pipeline/src/database/database.cpp @@ -1,11 +1,13 @@ #include "database/database.h" + #include + #include void SqliteDatabase::InitializeSchema() { - std::lock_guard lock(db_mutex_); + std::lock_guard lock(db_mutex_); - const char *schema = R"( + const char* schema = R"( CREATE TABLE IF NOT EXISTS countries ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, @@ -33,218 +35,219 @@ void SqliteDatabase::InitializeSchema() { ); )"; - char *errMsg = nullptr; - int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg); - if (rc != SQLITE_OK) { - std::string error = errMsg ? std::string(errMsg) : "Unknown error"; - sqlite3_free(errMsg); - throw std::runtime_error("Failed to create schema: " + error); - } + char* errMsg = nullptr; + int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg); + if (rc != SQLITE_OK) { + std::string error = errMsg ? std::string(errMsg) : "Unknown error"; + sqlite3_free(errMsg); + throw std::runtime_error("Failed to create schema: " + error); + } } SqliteDatabase::~SqliteDatabase() { - if (db_) { - sqlite3_close(db_); - } + if (db_) { + sqlite3_close(db_); + } } -void SqliteDatabase::Initialize(const std::string &db_path) { - int rc = sqlite3_open(db_path.c_str(), &db_); - if (rc) { - throw std::runtime_error("Failed to open SQLite database: " + db_path); - } - spdlog::info("OK: SQLite database opened: {}", db_path); - InitializeSchema(); +void SqliteDatabase::Initialize(const std::string& db_path) { + int rc = sqlite3_open(db_path.c_str(), &db_); + if (rc) { + throw std::runtime_error("Failed to open SQLite database: " + db_path); + } + spdlog::info("OK: SQLite database opened: {}", db_path); + InitializeSchema(); } void SqliteDatabase::BeginTransaction() { - std::lock_guard lock(db_mutex_); - char *err = nullptr; - if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) != - SQLITE_OK) { - std::string msg = err ? err : "unknown"; - sqlite3_free(err); - throw std::runtime_error("BeginTransaction failed: " + msg); - } + std::lock_guard lock(db_mutex_); + char* err = nullptr; + if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) != + SQLITE_OK) { + std::string msg = err ? err : "unknown"; + sqlite3_free(err); + throw std::runtime_error("BeginTransaction failed: " + msg); + } } void SqliteDatabase::CommitTransaction() { - std::lock_guard lock(db_mutex_); - char *err = nullptr; - if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { - std::string msg = err ? err : "unknown"; - sqlite3_free(err); - throw std::runtime_error("CommitTransaction failed: " + msg); - } + std::lock_guard lock(db_mutex_); + char* err = nullptr; + if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { + std::string msg = err ? err : "unknown"; + sqlite3_free(err); + throw std::runtime_error("CommitTransaction failed: " + msg); + } } -void SqliteDatabase::InsertCountry(int id, const std::string &name, - const std::string &iso2, - const std::string &iso3) { - std::lock_guard lock(db_mutex_); +void SqliteDatabase::InsertCountry(int id, const std::string& name, + const std::string& iso2, + const std::string& iso3) { + std::lock_guard lock(db_mutex_); - const char *query = R"( + const char* query = R"( INSERT OR IGNORE INTO countries (id, name, iso2, iso3) VALUES (?, ?, ?, ?) )"; - sqlite3_stmt *stmt; - int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); - if (rc != SQLITE_OK) - throw std::runtime_error("Failed to prepare country insert"); + sqlite3_stmt* stmt; + int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); + if (rc != SQLITE_OK) + throw std::runtime_error("Failed to prepare country insert"); - sqlite3_bind_int(stmt, 1, id); - sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_int(stmt, 1, id); + sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC); - if (sqlite3_step(stmt) != SQLITE_DONE) { - throw std::runtime_error("Failed to insert country"); - } - sqlite3_finalize(stmt); + if (sqlite3_step(stmt) != SQLITE_DONE) { + throw std::runtime_error("Failed to insert country"); + } + sqlite3_finalize(stmt); } -void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, - const std::string &iso2) { - std::lock_guard lock(db_mutex_); +void SqliteDatabase::InsertState(int id, int country_id, + const std::string& name, + const std::string& iso2) { + std::lock_guard lock(db_mutex_); - const char *query = R"( + const char* query = R"( INSERT OR IGNORE INTO states (id, country_id, name, iso2) VALUES (?, ?, ?, ?) )"; - sqlite3_stmt *stmt; - int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); - if (rc != SQLITE_OK) - throw std::runtime_error("Failed to prepare state insert"); + sqlite3_stmt* stmt; + int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); + if (rc != SQLITE_OK) + throw std::runtime_error("Failed to prepare state insert"); - sqlite3_bind_int(stmt, 1, id); - sqlite3_bind_int(stmt, 2, country_id); - sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_int(stmt, 1, id); + sqlite3_bind_int(stmt, 2, country_id); + sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); - if (sqlite3_step(stmt) != SQLITE_DONE) { - throw std::runtime_error("Failed to insert state"); - } - sqlite3_finalize(stmt); + if (sqlite3_step(stmt) != SQLITE_DONE) { + throw std::runtime_error("Failed to insert state"); + } + sqlite3_finalize(stmt); } void SqliteDatabase::InsertCity(int id, int state_id, int country_id, - const std::string &name, double latitude, + const std::string& name, double latitude, double longitude) { - std::lock_guard lock(db_mutex_); + std::lock_guard lock(db_mutex_); - const char *query = R"( + const char* query = R"( INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude) VALUES (?, ?, ?, ?, ?, ?) )"; - sqlite3_stmt *stmt; - int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); - if (rc != SQLITE_OK) - throw std::runtime_error("Failed to prepare city insert"); + sqlite3_stmt* stmt; + int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); + if (rc != SQLITE_OK) + throw std::runtime_error("Failed to prepare city insert"); - sqlite3_bind_int(stmt, 1, id); - sqlite3_bind_int(stmt, 2, state_id); - sqlite3_bind_int(stmt, 3, country_id); - sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_double(stmt, 5, latitude); - sqlite3_bind_double(stmt, 6, longitude); + sqlite3_bind_int(stmt, 1, id); + sqlite3_bind_int(stmt, 2, state_id); + sqlite3_bind_int(stmt, 3, country_id); + sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); + sqlite3_bind_double(stmt, 5, latitude); + sqlite3_bind_double(stmt, 6, longitude); - if (sqlite3_step(stmt) != SQLITE_DONE) { - throw std::runtime_error("Failed to insert city"); - } - sqlite3_finalize(stmt); + if (sqlite3_step(stmt) != SQLITE_DONE) { + throw std::runtime_error("Failed to insert city"); + } + sqlite3_finalize(stmt); } std::vector SqliteDatabase::QueryCities() { - std::lock_guard lock(db_mutex_); - std::vector cities; - sqlite3_stmt *stmt = nullptr; + std::lock_guard lock(db_mutex_); + std::vector cities; + sqlite3_stmt* stmt = nullptr; - const char *query = "SELECT id, name, country_id FROM cities ORDER BY name"; - int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); + const char* query = "SELECT id, name, country_id FROM cities ORDER BY name"; + int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); - if (rc != SQLITE_OK) { - throw std::runtime_error("Failed to prepare query"); - } + if (rc != SQLITE_OK) { + throw std::runtime_error("Failed to prepare query"); + } - while (sqlite3_step(stmt) == SQLITE_ROW) { - int id = sqlite3_column_int(stmt, 0); - const char *name = - reinterpret_cast(sqlite3_column_text(stmt, 1)); - int country_id = sqlite3_column_int(stmt, 2); - cities.push_back({id, name ? std::string(name) : "", country_id}); - } + while (sqlite3_step(stmt) == SQLITE_ROW) { + int id = sqlite3_column_int(stmt, 0); + const char* name = + reinterpret_cast(sqlite3_column_text(stmt, 1)); + int country_id = sqlite3_column_int(stmt, 2); + cities.push_back({id, name ? std::string(name) : "", country_id}); + } - sqlite3_finalize(stmt); - return cities; + sqlite3_finalize(stmt); + return cities; } std::vector SqliteDatabase::QueryCountries(int limit) { - std::lock_guard lock(db_mutex_); + std::lock_guard lock(db_mutex_); - std::vector countries; - sqlite3_stmt *stmt = nullptr; + std::vector countries; + sqlite3_stmt* stmt = nullptr; - std::string query = - "SELECT id, name, iso2, iso3 FROM countries ORDER BY name"; - if (limit > 0) { - query += " LIMIT " + std::to_string(limit); - } + std::string query = + "SELECT id, name, iso2, iso3 FROM countries ORDER BY name"; + if (limit > 0) { + query += " LIMIT " + std::to_string(limit); + } - int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); + int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); - if (rc != SQLITE_OK) { - throw std::runtime_error("Failed to prepare countries query"); - } + if (rc != SQLITE_OK) { + throw std::runtime_error("Failed to prepare countries query"); + } - while (sqlite3_step(stmt) == SQLITE_ROW) { - int id = sqlite3_column_int(stmt, 0); - const char *name = - reinterpret_cast(sqlite3_column_text(stmt, 1)); - const char *iso2 = - reinterpret_cast(sqlite3_column_text(stmt, 2)); - const char *iso3 = - reinterpret_cast(sqlite3_column_text(stmt, 3)); - countries.push_back({id, name ? std::string(name) : "", - iso2 ? std::string(iso2) : "", - iso3 ? std::string(iso3) : ""}); - } + while (sqlite3_step(stmt) == SQLITE_ROW) { + int id = sqlite3_column_int(stmt, 0); + const char* name = + reinterpret_cast(sqlite3_column_text(stmt, 1)); + const char* iso2 = + reinterpret_cast(sqlite3_column_text(stmt, 2)); + const char* iso3 = + reinterpret_cast(sqlite3_column_text(stmt, 3)); + countries.push_back({id, name ? std::string(name) : "", + iso2 ? std::string(iso2) : "", + iso3 ? std::string(iso3) : ""}); + } - sqlite3_finalize(stmt); - return countries; + sqlite3_finalize(stmt); + return countries; } std::vector SqliteDatabase::QueryStates(int limit) { - std::lock_guard lock(db_mutex_); + std::lock_guard lock(db_mutex_); - std::vector states; - sqlite3_stmt *stmt = nullptr; + std::vector states; + sqlite3_stmt* stmt = nullptr; - std::string query = - "SELECT id, name, iso2, country_id FROM states ORDER BY name"; - if (limit > 0) { - query += " LIMIT " + std::to_string(limit); - } + std::string query = + "SELECT id, name, iso2, country_id FROM states ORDER BY name"; + if (limit > 0) { + query += " LIMIT " + std::to_string(limit); + } - int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); + int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); - if (rc != SQLITE_OK) { - throw std::runtime_error("Failed to prepare states query"); - } + if (rc != SQLITE_OK) { + throw std::runtime_error("Failed to prepare states query"); + } - while (sqlite3_step(stmt) == SQLITE_ROW) { - int id = sqlite3_column_int(stmt, 0); - const char *name = - reinterpret_cast(sqlite3_column_text(stmt, 1)); - const char *iso2 = - reinterpret_cast(sqlite3_column_text(stmt, 2)); - int country_id = sqlite3_column_int(stmt, 3); - states.push_back({id, name ? std::string(name) : "", - iso2 ? std::string(iso2) : "", country_id}); - } + while (sqlite3_step(stmt) == SQLITE_ROW) { + int id = sqlite3_column_int(stmt, 0); + const char* name = + reinterpret_cast(sqlite3_column_text(stmt, 1)); + const char* iso2 = + reinterpret_cast(sqlite3_column_text(stmt, 2)); + int country_id = sqlite3_column_int(stmt, 3); + states.push_back({id, name ? std::string(name) : "", + iso2 ? std::string(iso2) : "", country_id}); + } - sqlite3_finalize(stmt); - return states; + sqlite3_finalize(stmt); + return states; } diff --git a/pipeline/src/json_handling/json_loader.cpp b/pipeline/src/json_handling/json_loader.cpp index bfc8d12..71875a1 100644 --- a/pipeline/src/json_handling/json_loader.cpp +++ b/pipeline/src/json_handling/json_loader.cpp @@ -1,65 +1,66 @@ -#include +#include "json_handling/json_loader.h" #include -#include "json_handling/json_loader.h" +#include + #include "json_handling/stream_parser.h" -void JsonLoader::LoadWorldCities(const std::string &json_path, - SqliteDatabase &db) { - constexpr size_t kBatchSize = 10000; +void JsonLoader::LoadWorldCities(const std::string& json_path, + SqliteDatabase& db) { + constexpr size_t kBatchSize = 10000; - auto startTime = std::chrono::high_resolution_clock::now(); - spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path); + auto startTime = std::chrono::high_resolution_clock::now(); + spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path); - db.BeginTransaction(); - bool transactionOpen = true; + db.BeginTransaction(); + bool transactionOpen = true; - size_t citiesProcessed = 0; - try { - StreamingJsonParser::Parse( - json_path, db, - [&](const CityRecord &record) { - db.InsertCity(record.id, record.state_id, record.country_id, - record.name, record.latitude, record.longitude); - ++citiesProcessed; + size_t citiesProcessed = 0; + try { + StreamingJsonParser::Parse( + json_path, db, + [&](const CityRecord& record) { + db.InsertCity(record.id, record.state_id, record.country_id, + record.name, record.latitude, record.longitude); + ++citiesProcessed; - if (citiesProcessed % kBatchSize == 0) { - db.CommitTransaction(); - db.BeginTransaction(); - } - }, - [&](size_t current, size_t /*total*/) { - if (current % kBatchSize == 0 && current > 0) { - spdlog::info(" [Progress] Parsed {} cities...", current); - } - }); + if (citiesProcessed % kBatchSize == 0) { + db.CommitTransaction(); + db.BeginTransaction(); + } + }, + [&](size_t current, size_t /*total*/) { + if (current % kBatchSize == 0 && current > 0) { + spdlog::info(" [Progress] Parsed {} cities...", current); + } + }); - spdlog::info(" OK: Parsed all cities from JSON"); + spdlog::info(" OK: Parsed all cities from JSON"); - if (transactionOpen) { - db.CommitTransaction(); - transactionOpen = false; - } - } catch (...) { - if (transactionOpen) { - db.CommitTransaction(); - } - throw; - } + if (transactionOpen) { + db.CommitTransaction(); + transactionOpen = false; + } + } catch (...) { + if (transactionOpen) { + db.CommitTransaction(); + } + throw; + } - auto endTime = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast( - endTime - startTime); + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + endTime - startTime); - spdlog::info("\n=== World City Data Loading Summary ===\n"); - spdlog::info("Cities inserted: {}", citiesProcessed); - spdlog::info("Elapsed time: {} ms", duration.count()); - long long throughput = - (citiesProcessed > 0 && duration.count() > 0) - ? (1000LL * static_cast(citiesProcessed)) / - static_cast(duration.count()) - : 0LL; - spdlog::info("Throughput: {} cities/sec", throughput); - spdlog::info("=======================================\n"); + spdlog::info("\n=== World City Data Loading Summary ===\n"); + spdlog::info("Cities inserted: {}", citiesProcessed); + spdlog::info("Elapsed time: {} ms", duration.count()); + long long throughput = + (citiesProcessed > 0 && duration.count() > 0) + ? (1000LL * static_cast(citiesProcessed)) / + static_cast(duration.count()) + : 0LL; + spdlog::info("Throughput: {} cities/sec", throughput); + spdlog::info("=======================================\n"); } diff --git a/pipeline/src/json_handling/stream_parser.cpp b/pipeline/src/json_handling/stream_parser.cpp index 1108f74..68dcf0e 100644 --- a/pipeline/src/json_handling/stream_parser.cpp +++ b/pipeline/src/json_handling/stream_parser.cpp @@ -1,288 +1,289 @@ -#include -#include +#include "json_handling/stream_parser.h" + +#include #include #include -#include +#include +#include #include "database/database.h" -#include "json_handling/stream_parser.h" class CityRecordHandler { - friend class boost::json::basic_parser; + friend class boost::json::basic_parser; -public: - static constexpr std::size_t max_array_size = static_cast(-1); - static constexpr std::size_t max_object_size = static_cast(-1); - static constexpr std::size_t max_string_size = static_cast(-1); - static constexpr std::size_t max_key_size = static_cast(-1); + public: + static constexpr std::size_t max_array_size = static_cast(-1); + static constexpr std::size_t max_object_size = static_cast(-1); + static constexpr std::size_t max_string_size = static_cast(-1); + static constexpr std::size_t max_key_size = static_cast(-1); - struct ParseContext { - SqliteDatabase *db = nullptr; - std::function on_city; - std::function on_progress; - size_t cities_emitted = 0; - size_t total_file_size = 0; - int countries_inserted = 0; - int states_inserted = 0; - }; + struct ParseContext { + SqliteDatabase* db = nullptr; + std::function on_city; + std::function on_progress; + size_t cities_emitted = 0; + size_t total_file_size = 0; + int countries_inserted = 0; + int states_inserted = 0; + }; - explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {} + explicit CityRecordHandler(ParseContext& ctx) : context(ctx) {} -private: - ParseContext &context; + private: + ParseContext& context; - int depth = 0; - bool in_countries_array = false; - bool in_country_object = false; - bool in_states_array = false; - bool in_state_object = false; - bool in_cities_array = false; - bool building_city = false; + int depth = 0; + bool in_countries_array = false; + bool in_country_object = false; + bool in_states_array = false; + bool in_state_object = false; + bool in_cities_array = false; + bool building_city = false; - int current_country_id = 0; - int current_state_id = 0; - CityRecord current_city = {}; - std::string current_key; - std::string current_key_val; - std::string current_string_val; + int current_country_id = 0; + int current_state_id = 0; + CityRecord current_city = {}; + std::string current_key; + std::string current_key_val; + std::string current_string_val; - std::string country_info[3]; - std::string state_info[2]; + std::string country_info[3]; + std::string state_info[2]; - // Boost.JSON SAX Hooks - bool on_document_begin(boost::system::error_code &) { return true; } - bool on_document_end(boost::system::error_code &) { return true; } + // Boost.JSON SAX Hooks + bool on_document_begin(boost::system::error_code&) { return true; } + bool on_document_end(boost::system::error_code&) { return true; } - bool on_array_begin(boost::system::error_code &) { - depth++; - if (depth == 1) { - in_countries_array = true; - } else if (depth == 3 && current_key == "states") { - in_states_array = true; - } else if (depth == 5 && current_key == "cities") { - in_cities_array = true; - } - return true; - } - - bool on_array_end(std::size_t, boost::system::error_code &) { - if (depth == 1) { - in_countries_array = false; - } else if (depth == 3) { - in_states_array = false; - } else if (depth == 5) { - in_cities_array = false; - } - depth--; - return true; - } - - bool on_object_begin(boost::system::error_code &) { - depth++; - if (depth == 2 && in_countries_array) { - in_country_object = true; - current_country_id = 0; - country_info[0].clear(); - country_info[1].clear(); - country_info[2].clear(); - } else if (depth == 4 && in_states_array) { - in_state_object = true; - current_state_id = 0; - state_info[0].clear(); - state_info[1].clear(); - } else if (depth == 6 && in_cities_array) { - building_city = true; - current_city = {}; - } - return true; - } - - bool on_object_end(std::size_t, boost::system::error_code &) { - if (depth == 6 && building_city) { - if (current_city.id > 0 && current_state_id > 0 && - current_country_id > 0) { - current_city.state_id = current_state_id; - current_city.country_id = current_country_id; - - try { - context.on_city(current_city); - context.cities_emitted++; - - if (context.on_progress && context.cities_emitted % 10000 == 0) { - context.on_progress(context.cities_emitted, - context.total_file_size); - } - } catch (const std::exception &e) { - spdlog::warn("Record parsing failed: {}", e.what()); - } + bool on_array_begin(boost::system::error_code&) { + depth++; + if (depth == 1) { + in_countries_array = true; + } else if (depth == 3 && current_key == "states") { + in_states_array = true; + } else if (depth == 5 && current_key == "cities") { + in_cities_array = true; } - building_city = false; - } else if (depth == 4 && in_state_object) { - if (current_state_id > 0 && current_country_id > 0) { - try { - context.db->InsertState(current_state_id, current_country_id, - state_info[0], state_info[1]); - context.states_inserted++; - } catch (const std::exception &e) { - spdlog::warn("Record parsing failed: {}", e.what()); - } + return true; + } + + bool on_array_end(std::size_t, boost::system::error_code&) { + if (depth == 1) { + in_countries_array = false; + } else if (depth == 3) { + in_states_array = false; + } else if (depth == 5) { + in_cities_array = false; } - in_state_object = false; - } else if (depth == 2 && in_country_object) { - if (current_country_id > 0) { - try { - context.db->InsertCountry(current_country_id, country_info[0], - country_info[1], country_info[2]); - context.countries_inserted++; - } catch (const std::exception &e) { - spdlog::warn("Record parsing failed: {}", e.what()); - } + depth--; + return true; + } + + bool on_object_begin(boost::system::error_code&) { + depth++; + if (depth == 2 && in_countries_array) { + in_country_object = true; + current_country_id = 0; + country_info[0].clear(); + country_info[1].clear(); + country_info[2].clear(); + } else if (depth == 4 && in_states_array) { + in_state_object = true; + current_state_id = 0; + state_info[0].clear(); + state_info[1].clear(); + } else if (depth == 6 && in_cities_array) { + building_city = true; + current_city = {}; } - in_country_object = false; - } + return true; + } - depth--; - return true; - } + bool on_object_end(std::size_t, boost::system::error_code&) { + if (depth == 6 && building_city) { + if (current_city.id > 0 && current_state_id > 0 && + current_country_id > 0) { + current_city.state_id = current_state_id; + current_city.country_id = current_country_id; - bool on_key_part(boost::json::string_view s, std::size_t, - boost::system::error_code &) { - current_key_val.append(s.data(), s.size()); - return true; - } + try { + context.on_city(current_city); + context.cities_emitted++; - bool on_key(boost::json::string_view s, std::size_t, - boost::system::error_code &) { - current_key_val.append(s.data(), s.size()); - current_key = current_key_val; - current_key_val.clear(); - return true; - } - - bool on_string_part(boost::json::string_view s, std::size_t, - boost::system::error_code &) { - current_string_val.append(s.data(), s.size()); - return true; - } - - bool on_string(boost::json::string_view s, std::size_t, - boost::system::error_code &) { - current_string_val.append(s.data(), s.size()); - - if (building_city && current_key == "name") { - current_city.name = current_string_val; - } else if (in_state_object && current_key == "name") { - state_info[0] = current_string_val; - } else if (in_state_object && current_key == "iso2") { - state_info[1] = current_string_val; - } else if (in_country_object && current_key == "name") { - country_info[0] = current_string_val; - } else if (in_country_object && current_key == "iso2") { - country_info[1] = current_string_val; - } else if (in_country_object && current_key == "iso3") { - country_info[2] = current_string_val; - } - - current_string_val.clear(); - return true; - } - - bool on_number_part(boost::json::string_view, boost::system::error_code &) { - return true; - } - - bool on_int64(int64_t i, boost::json::string_view, - boost::system::error_code &) { - if (building_city && current_key == "id") { - current_city.id = static_cast(i); - } else if (in_state_object && current_key == "id") { - current_state_id = static_cast(i); - } else if (in_country_object && current_key == "id") { - current_country_id = static_cast(i); - } - return true; - } - - bool on_uint64(uint64_t u, boost::json::string_view, - boost::system::error_code &ec) { - return on_int64(static_cast(u), "", ec); - } - - bool on_double(double d, boost::json::string_view, - boost::system::error_code &) { - if (building_city) { - if (current_key == "latitude") { - current_city.latitude = d; - } else if (current_key == "longitude") { - current_city.longitude = d; + if (context.on_progress && context.cities_emitted % 10000 == 0) { + context.on_progress(context.cities_emitted, + context.total_file_size); + } + } catch (const std::exception& e) { + spdlog::warn("Record parsing failed: {}", e.what()); + } + } + building_city = false; + } else if (depth == 4 && in_state_object) { + if (current_state_id > 0 && current_country_id > 0) { + try { + context.db->InsertState(current_state_id, current_country_id, + state_info[0], state_info[1]); + context.states_inserted++; + } catch (const std::exception& e) { + spdlog::warn("Record parsing failed: {}", e.what()); + } + } + in_state_object = false; + } else if (depth == 2 && in_country_object) { + if (current_country_id > 0) { + try { + context.db->InsertCountry(current_country_id, country_info[0], + country_info[1], country_info[2]); + context.countries_inserted++; + } catch (const std::exception& e) { + spdlog::warn("Record parsing failed: {}", e.what()); + } + } + in_country_object = false; } - } - return true; - } - bool on_bool(bool, boost::system::error_code &) { return true; } - bool on_null(boost::system::error_code &) { return true; } - bool on_comment_part(boost::json::string_view, boost::system::error_code &) { - return true; - } - bool on_comment(boost::json::string_view, boost::system::error_code &) { - return true; - } + depth--; + return true; + } + + bool on_key_part(boost::json::string_view s, std::size_t, + boost::system::error_code&) { + current_key_val.append(s.data(), s.size()); + return true; + } + + bool on_key(boost::json::string_view s, std::size_t, + boost::system::error_code&) { + current_key_val.append(s.data(), s.size()); + current_key = current_key_val; + current_key_val.clear(); + return true; + } + + bool on_string_part(boost::json::string_view s, std::size_t, + boost::system::error_code&) { + current_string_val.append(s.data(), s.size()); + return true; + } + + bool on_string(boost::json::string_view s, std::size_t, + boost::system::error_code&) { + current_string_val.append(s.data(), s.size()); + + if (building_city && current_key == "name") { + current_city.name = current_string_val; + } else if (in_state_object && current_key == "name") { + state_info[0] = current_string_val; + } else if (in_state_object && current_key == "iso2") { + state_info[1] = current_string_val; + } else if (in_country_object && current_key == "name") { + country_info[0] = current_string_val; + } else if (in_country_object && current_key == "iso2") { + country_info[1] = current_string_val; + } else if (in_country_object && current_key == "iso3") { + country_info[2] = current_string_val; + } + + current_string_val.clear(); + return true; + } + + bool on_number_part(boost::json::string_view, boost::system::error_code&) { + return true; + } + + bool on_int64(int64_t i, boost::json::string_view, + boost::system::error_code&) { + if (building_city && current_key == "id") { + current_city.id = static_cast(i); + } else if (in_state_object && current_key == "id") { + current_state_id = static_cast(i); + } else if (in_country_object && current_key == "id") { + current_country_id = static_cast(i); + } + return true; + } + + bool on_uint64(uint64_t u, boost::json::string_view, + boost::system::error_code& ec) { + return on_int64(static_cast(u), "", ec); + } + + bool on_double(double d, boost::json::string_view, + boost::system::error_code&) { + if (building_city) { + if (current_key == "latitude") { + current_city.latitude = d; + } else if (current_key == "longitude") { + current_city.longitude = d; + } + } + return true; + } + + bool on_bool(bool, boost::system::error_code&) { return true; } + bool on_null(boost::system::error_code&) { return true; } + bool on_comment_part(boost::json::string_view, boost::system::error_code&) { + return true; + } + bool on_comment(boost::json::string_view, boost::system::error_code&) { + return true; + } }; void StreamingJsonParser::Parse( - const std::string &file_path, SqliteDatabase &db, - std::function on_city, + const std::string& file_path, SqliteDatabase& db, + std::function on_city, std::function on_progress) { + spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); - spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); + FILE* file = std::fopen(file_path.c_str(), "rb"); + if (!file) { + throw std::runtime_error("Failed to open JSON file: " + file_path); + } - FILE *file = std::fopen(file_path.c_str(), "rb"); - if (!file) { - throw std::runtime_error("Failed to open JSON file: " + file_path); - } - - size_t total_size = 0; - if (std::fseek(file, 0, SEEK_END) == 0) { - long file_size = std::ftell(file); - if (file_size > 0) { - total_size = static_cast(file_size); - } - std::rewind(file); - } - - CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, - total_size, 0, 0}; - boost::json::basic_parser parser( - boost::json::parse_options{}, ctx); - - char buf[65536]; - size_t bytes_read; - boost::system::error_code ec; - - while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) { - char const *p = buf; - std::size_t remain = bytes_read; - - while (remain > 0) { - std::size_t consumed = parser.write_some(true, p, remain, ec); - if (ec) { - std::fclose(file); - throw std::runtime_error("JSON parse error: " + ec.message()); + size_t total_size = 0; + if (std::fseek(file, 0, SEEK_END) == 0) { + long file_size = std::ftell(file); + if (file_size > 0) { + total_size = static_cast(file_size); } - p += consumed; - remain -= consumed; - } - } + std::rewind(file); + } - parser.write_some(false, nullptr, 0, ec); // Signal EOF - std::fclose(file); + CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size, + 0, 0}; + boost::json::basic_parser parser( + boost::json::parse_options{}, ctx); - if (ec) { - throw std::runtime_error("JSON parse error at EOF: " + ec.message()); - } + char buf[65536]; + size_t bytes_read; + boost::system::error_code ec; - spdlog::info(" OK: Parsed {} countries, {} states, {} cities", - ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); + while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) { + char const* p = buf; + std::size_t remain = bytes_read; + + while (remain > 0) { + std::size_t consumed = parser.write_some(true, p, remain, ec); + if (ec) { + std::fclose(file); + throw std::runtime_error("JSON parse error: " + ec.message()); + } + p += consumed; + remain -= consumed; + } + } + + parser.write_some(false, nullptr, 0, ec); // Signal EOF + std::fclose(file); + + if (ec) { + throw std::runtime_error("JSON parse error at EOF: " + ec.message()); + } + + spdlog::info(" OK: Parsed {} countries, {} states, {} cities", + ctx.countries_inserted, ctx.states_inserted, + ctx.cities_emitted); } diff --git a/pipeline/src/web_client/curl_web_client.cpp b/pipeline/src/web_client/curl_web_client.cpp index eba5a47..2622f80 100644 --- a/pipeline/src/web_client/curl_web_client.cpp +++ b/pipeline/src/web_client/curl_web_client.cpp @@ -1,139 +1,141 @@ #include "web_client/curl_web_client.h" -#include + #include + +#include #include #include #include #include CurlGlobalState::CurlGlobalState() { - if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) { - throw std::runtime_error( - "[CURLWebClient] Failed to initialize libcurl globally"); - } + if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) { + throw std::runtime_error( + "[CURLWebClient] Failed to initialize libcurl globally"); + } } CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); } namespace { // curl write callback that appends response data into a std::string -size_t WriteCallbackString(void *contents, size_t size, size_t nmemb, - void *userp) { - size_t realsize = size * nmemb; - auto *s = static_cast(userp); - s->append(static_cast(contents), realsize); - return realsize; +size_t WriteCallbackString(void* contents, size_t size, size_t nmemb, + void* userp) { + size_t realsize = size * nmemb; + auto* s = static_cast(userp); + s->append(static_cast(contents), realsize); + return realsize; } // curl write callback that writes to a file stream - size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb, - void *userp) { - size_t realsize = size * nmemb; - auto *outFile = static_cast(userp); - outFile->write(static_cast(contents), realsize); - return realsize; +size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb, + void* userp) { + size_t realsize = size * nmemb; + auto* outFile = static_cast(userp); + outFile->write(static_cast(contents), realsize); + return realsize; } // RAII wrapper for CURL handle using unique_ptr using CurlHandle = std::unique_ptr; CurlHandle create_handle() { - CURL *handle = curl_easy_init(); - if (!handle) { - throw std::runtime_error( - "[CURLWebClient] Failed to initialize libcurl handle"); - } - return CurlHandle(handle, &curl_easy_cleanup); + CURL* handle = curl_easy_init(); + if (!handle) { + throw std::runtime_error( + "[CURLWebClient] Failed to initialize libcurl handle"); + } + return CurlHandle(handle, &curl_easy_cleanup); } -void set_common_get_options(CURL *curl, const std::string &url, +void set_common_get_options(CURL* curl, const std::string& url, long connect_timeout, long total_timeout) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); - curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout); - curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout); - curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout); + curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); } -} // namespace +} // namespace CURLWebClient::CURLWebClient() {} CURLWebClient::~CURLWebClient() {} -void CURLWebClient::DownloadToFile(const std::string &url, - const std::string &file_path) { - auto curl = create_handle(); +void CURLWebClient::DownloadToFile(const std::string& url, + const std::string& file_path) { + auto curl = create_handle(); - std::ofstream outFile(file_path, std::ios::binary); - if (!outFile.is_open()) { - throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + - file_path); - } + std::ofstream outFile(file_path, std::ios::binary); + if (!outFile.is_open()) { + throw std::runtime_error( + "[CURLWebClient] Cannot open file for writing: " + file_path); + } - set_common_get_options(curl.get(), url, 30L, 300L); - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, - static_cast(&outFile)); + set_common_get_options(curl.get(), url, 30L, 300L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, + static_cast(&outFile)); - CURLcode res = curl_easy_perform(curl.get()); - outFile.close(); + CURLcode res = curl_easy_perform(curl.get()); + outFile.close(); - if (res != CURLE_OK) { - std::remove(file_path.c_str()); - std::string error = std::string("[CURLWebClient] Download failed: ") + - curl_easy_strerror(res); - throw std::runtime_error(error); - } + if (res != CURLE_OK) { + std::remove(file_path.c_str()); + std::string error = std::string("[CURLWebClient] Download failed: ") + + curl_easy_strerror(res); + throw std::runtime_error(error); + } - long httpCode = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); + long httpCode = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); - if (httpCode != 200) { - std::remove(file_path.c_str()); - std::stringstream ss; - ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; - throw std::runtime_error(ss.str()); - } + if (httpCode != 200) { + std::remove(file_path.c_str()); + std::stringstream ss; + ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; + throw std::runtime_error(ss.str()); + } } -std::string CURLWebClient::Get(const std::string &url) { - auto curl = create_handle(); +std::string CURLWebClient::Get(const std::string& url) { + auto curl = create_handle(); - std::string response_string; - set_common_get_options(curl.get(), url, 10L, 20L); - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); + std::string response_string; + set_common_get_options(curl.get(), url, 10L, 20L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); - CURLcode res = curl_easy_perform(curl.get()); + CURLcode res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - std::string error = - std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); - throw std::runtime_error(error); - } + if (res != CURLE_OK) { + std::string error = + std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); + throw std::runtime_error(error); + } - long httpCode = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); + long httpCode = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); - if (httpCode != 200) { - std::stringstream ss; - ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; - throw std::runtime_error(ss.str()); - } + if (httpCode != 200) { + std::stringstream ss; + ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; + throw std::runtime_error(ss.str()); + } - return response_string; + return response_string; } -std::string CURLWebClient::UrlEncode(const std::string &value) { - // A NULL handle is fine for UTF-8 encoding according to libcurl docs. - char *output = curl_easy_escape(nullptr, value.c_str(), 0); +std::string CURLWebClient::UrlEncode(const std::string& value) { + // A NULL handle is fine for UTF-8 encoding according to libcurl docs. + char* output = curl_easy_escape(nullptr, value.c_str(), 0); - if (output) { - std::string result(output); - curl_free(output); - return result; - } - throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); + if (output) { + std::string result(output); + curl_free(output); + return result; + } + throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); } diff --git a/pipeline/src/wikipedia/wikipedia_service.cpp b/pipeline/src/wikipedia/wikipedia_service.cpp index c42bf27..ef851bb 100644 --- a/pipeline/src/wikipedia/wikipedia_service.cpp +++ b/pipeline/src/wikipedia/wikipedia_service.cpp @@ -1,77 +1,78 @@ #include "wikipedia/wikipedia_service.h" -#include + #include +#include + WikipediaService::WikipediaService(std::shared_ptr client) : client_(std::move(client)) {} std::string WikipediaService::FetchExtract(std::string_view query) { - const std::string encoded = client_->UrlEncode(std::string(query)); - const std::string url = - "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + - "&prop=extracts&explaintext=true&format=json"; + const std::string encoded = client_->UrlEncode(std::string(query)); + const std::string url = + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + + "&prop=extracts&explaintext=true&format=json"; - const std::string body = client_->Get(url); + const std::string body = client_->Get(url); - boost::system::error_code ec; - boost::json::value doc = boost::json::parse(body, ec); + boost::system::error_code ec; + boost::json::value doc = boost::json::parse(body, ec); - if (!ec && doc.is_object()) { - auto &pages = doc.at("query").at("pages").get_object(); - if (!pages.empty()) { - auto &page = pages.begin()->value().get_object(); - if (page.contains("extract") && page.at("extract").is_string()) { - std::string extract(page.at("extract").as_string().c_str()); - spdlog::debug("WikipediaService fetched {} chars for '{}'", - extract.size(), query); - return extract; + if (!ec && doc.is_object()) { + auto& pages = doc.at("query").at("pages").get_object(); + if (!pages.empty()) { + auto& page = pages.begin()->value().get_object(); + if (page.contains("extract") && page.at("extract").is_string()) { + std::string extract(page.at("extract").as_string().c_str()); + spdlog::debug("WikipediaService fetched {} chars for '{}'", + extract.size(), query); + return extract; + } } - } - } + } - return {}; + return {}; } std::string WikipediaService::GetSummary(std::string_view city, std::string_view country) { - const std::string key = std::string(city) + "|" + std::string(country); - const auto cacheIt = cache_.find(key); - if (cacheIt != cache_.end()) { - return cacheIt->second; - } + const std::string key = std::string(city) + "|" + std::string(country); + const auto cacheIt = cache_.find(key); + if (cacheIt != cache_.end()) { + return cacheIt->second; + } - std::string result; + std::string result; - if (!client_) { - cache_.emplace(key, result); - return result; - } + if (!client_) { + cache_.emplace(key, result); + return result; + } - std::string regionQuery(city); - if (!country.empty()) { - regionQuery += ", "; - regionQuery += country; - } + std::string regionQuery(city); + if (!country.empty()) { + regionQuery += ", "; + regionQuery += country; + } - const std::string beerQuery = "beer in " + std::string(city); + const std::string beerQuery = "beer in " + std::string(city); - try { - const std::string regionExtract = FetchExtract(regionQuery); - const std::string beerExtract = FetchExtract(beerQuery); + try { + const std::string regionExtract = FetchExtract(regionQuery); + const std::string beerExtract = FetchExtract(beerQuery); - if (!regionExtract.empty()) { - result += regionExtract; - } - if (!beerExtract.empty()) { - if (!result.empty()) - result += "\n\n"; - result += beerExtract; - } - } catch (const std::runtime_error &e) { - spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, - e.what()); - } + if (!regionExtract.empty()) { + result += regionExtract; + } + if (!beerExtract.empty()) { + if (!result.empty()) result += "\n\n"; + result += beerExtract; + } + } catch (const std::runtime_error& e) { + spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, + e.what()); + } - cache_.emplace(key, result); - return result; + cache_.emplace(key, result); + return result; }